From cdfba77ba2164c0b36268dd2541e40a4d3bb15a0 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 13:28:30 -0500 Subject: [PATCH 01/33] Add partitioning of rules by shared EClassesToSearch for optimized rule application --- .../eqsat/rewriting/EClassSearcher.scala | 72 +++++++++++++++++++ .../eqsat/saturation/SearchAndApply.scala | 71 ++++++++++++++++-- 2 files changed, 136 insertions(+), 7 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala index 9bc01f66..8e9a3c79 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala @@ -3,6 +3,7 @@ package foresight.eqsat.rewriting import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.EClassCall import foresight.eqsat.readonly.EGraph +import foresight.util.collections.StrictMapOps.toStrictMapOps import java.util.concurrent.atomic.AtomicIntegerArray import scala.collection.compat.immutable.ArraySeq @@ -77,6 +78,12 @@ private[eqsat] object EClassSearcher { */ final val blockSize: Int = 64 + /** + * The threshold size (in number of e-classes) below which we consider an e-graph "small" + * and may choose simpler strategies that have a lower constant overhead. + */ + final val smallEGraphThreshold: Int = 4 * blockSize + /** * Searches for matches within multiple e-classes in parallel. * @@ -135,4 +142,69 @@ private[eqsat] object EClassSearcher { case HaltSearchException => // Swallow the exception to halt the search early } } + + /** + * Result of partitioning rules by shared EClassesToSearch instances. + * + * @param rulesPerSharedEClassToSearch Map from shared EClassesToSearch instances to the rules that use them. + * @param regularRules Rules that either don't use EClassSearcher or use unique EClassesToSearch. + */ + final case class PartitionedRules[ + NodeT, + MatchT, + EGraphT <: EGraph[NodeT] + ]( + rulesPerSharedEClassToSearch: Map[EClassesToSearch[EGraphT], Seq[Rewrite[NodeT, MatchT, EGraphT]]], + regularRules: Seq[Rewrite[NodeT, MatchT, EGraphT]] + ) + + /** + * Partitions a sequence of rules into those that share EClassesToSearch instances and those that do not. + * + * This is useful for optimizing rule application by grouping rules that search the same e-classes. + * @param rules Sequence of rules to partition. + * @tparam NodeT Node payload type. + * @tparam MatchT Match element type. + * @tparam EGraphT E-graph type. + * @return A [[PartitionedRules]] instance containing the partitioned rules. + */ + def partitionRulesBySharedEClassesToSearch[ + NodeT, + MatchT, + EGraphT <: EGraph[NodeT] + ](rules: Seq[Rewrite[NodeT, MatchT, EGraphT]]): PartitionedRules[NodeT, MatchT, EGraphT] = { + // Collect the EClassesToSearch instances from all EClassSearcher rules. + val eclassesToSearchPerRule = rules.collect { + case Rule(name, searcher: EClassSearcher[NodeT, MatchT, _], _) => + name -> searcher.classesToSearch + }.toMap + + // Count how many times each unique EClassesToSearch instance is used across all rules. + val usageCounts = + eclassesToSearchPerRule + .values + .groupBy(identity) // groups by unique EClassesToSearch instance + .mapValuesStrict(_.size) // count how many times each one appears + + // For shared EClassesToSearch, group the corresponding rules together. + val rulesPerSharedEClassToSearch = + usageCounts + .filter { case (_, count) => count > 1 } + .keys + .map { eclassesToSearch => + eclassesToSearch -> rules.collect { + case rule if eclassesToSearchPerRule.get(rule.name).contains(eclassesToSearch) => + rule + } + }.toMap + + // Rules that either don't use EClassSearcher or use a unique EClassesToSearch. + val regularRules: Seq[Rewrite[NodeT, MatchT, EGraphT]] = + rules.filter { rule => + !eclassesToSearchPerRule.contains(rule.name) || + usageCounts(eclassesToSearchPerRule(rule.name)) == 1 + } + + PartitionedRules(rulesPerSharedEClassToSearch, regularRules) + } } diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index 3081f3e2..6b2a6105 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -2,11 +2,13 @@ package foresight.eqsat.saturation import foresight.eqsat.commands.{Command, CommandQueue} import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.rewriting.{PortableMatch, Rewrite} +import foresight.eqsat.rewriting.{EClassSearcher, EClassesToSearch, PortableMatch, Rewrite, Rule} import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithRecordedApplications} import foresight.eqsat.mutable.{FreezableEGraph, EGraph => MutableEGraph} import foresight.eqsat.readonly +import foresight.eqsat.rewriting.SearcherContinuation.Continuation import foresight.util.collections.StrictMapOps.toStrictMapOps +import foresight.util.collections.UnsafeSeqFromArray import scala.collection.mutable.HashMap @@ -263,15 +265,70 @@ object SearchAndApply { final override def apply(rules: Seq[Rewrite[NodeT, MatchT, EGraphT]], egraph: EGraphT, parallelize: ParallelMap): Option[EGraphT] = { + + val updates = Seq.newBuilder[Command[NodeT]] val ruleMatchingAndApplicationParallelize = parallelize.child("rule matching+application") - val updates = ruleMatchingAndApplicationParallelize( - rules, - (rule: Rewrite[NodeT, MatchT, EGraphT]) => { - rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + + if (egraph.classCount <= EClassSearcher.smallEGraphThreshold) { + // Small e-graph optimization: for small e-graphs, the overhead of partitioning and + // fusing rule applications outweighs the benefits. Just process each rule normally. + for (rule <- rules) { + updates += rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + } + } else { + // Idea: EClassSearcher rules are the common case, and they apply in parallel over a subset of + // e-classes in the e-graph. If multiple rules share the same subset of e-classes to search, + // we can group them together to fuse iterations over those e-classes. Fusion both reduces + // redundant work and increases the amount of work done per e-class per iteration, allowing + // for better parallelism. + val partitioned = EClassSearcher.partitionRulesBySharedEClassesToSearch(rules) + + // Process regular rules normally. + for (rule <- partitioned.regularRules) { + updates += rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + } + + // Process shared EClassesToSearch rules together. + for ((eclassesToSearch, sharedRules) <- partitioned.rulesPerSharedEClassToSearch) { + updates ++= ruleMatchingAndApplicationParallelize.collectFrom[Command[NodeT]] { (add: Command[NodeT] => Unit) => + // Build combined searchers that first search and for each match apply the corresponding rule's applier. + // Each combined searcher corresponds to one of the shared rules. + val commandSearchers = sharedRules.map { + case Rule(_, searcher: EClassSearcher[NodeT, MatchT, _], applier) => + val castSearcher = searcher.asInstanceOf[EClassSearcher[NodeT, MatchT, EGraphT]] + + castSearcher + .andThen(new castSearcher.ContinuationBuilder { + def apply(downstream: castSearcher.Continuation): Continuation[NodeT, MatchT, EGraphT] = (m: MatchT, egraph: EGraphT) => { + if (downstream(m, egraph)) { + applier(m, egraph) match { + case CommandQueue(Seq()) => // Ignore no-op commands. + case cmd => + // Collect nontrivial commands. + add(cmd) + } + true + } else { + false + } + } + }) + + case _ => + throw new IllegalStateException("Expected only EClassSearcher rules in shared EClassesToSearch group.") + } + + EClassSearcher.searchMultiple( + UnsafeSeqFromArray(commandSearchers.toArray), + eclassesToSearch(egraph), + egraph, + ruleMatchingAndApplicationParallelize + ) + } } - ).toSeq + } - update(updates, Map.empty[String, Seq[MatchT]], egraph, parallelize) + update(updates.result(), Map.empty[String, Seq[MatchT]], egraph, parallelize) } } } From 7c70f351e192e1f6abafec864395c43a2e577019 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 13:44:15 -0500 Subject: [PATCH 02/33] Refine type casting for EClassesToSearch in EClassSearcher rule collection --- .../main/scala/foresight/eqsat/rewriting/EClassSearcher.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala index 8e9a3c79..49eecc25 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala @@ -176,7 +176,7 @@ private[eqsat] object EClassSearcher { // Collect the EClassesToSearch instances from all EClassSearcher rules. val eclassesToSearchPerRule = rules.collect { case Rule(name, searcher: EClassSearcher[NodeT, MatchT, _], _) => - name -> searcher.classesToSearch + name -> searcher.classesToSearch.asInstanceOf[EClassesToSearch[EGraphT]] }.toMap // Count how many times each unique EClassesToSearch instance is used across all rules. From de9013ae4a667435a28cbd49e0b4cf197bdcafa4 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 13:48:00 -0500 Subject: [PATCH 03/33] Refactor type handling in EClassSearcher for improved rule processing --- .../main/scala/foresight/eqsat/rewriting/EClassSearcher.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala index 49eecc25..202b659f 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala @@ -154,7 +154,7 @@ private[eqsat] object EClassSearcher { MatchT, EGraphT <: EGraph[NodeT] ]( - rulesPerSharedEClassToSearch: Map[EClassesToSearch[EGraphT], Seq[Rewrite[NodeT, MatchT, EGraphT]]], + rulesPerSharedEClassToSearch: Map[EClassesToSearch[EGraphT], Seq[Rule[NodeT, MatchT, EGraphT]]], regularRules: Seq[Rewrite[NodeT, MatchT, EGraphT]] ) @@ -194,7 +194,7 @@ private[eqsat] object EClassSearcher { .map { eclassesToSearch => eclassesToSearch -> rules.collect { case rule if eclassesToSearchPerRule.get(rule.name).contains(eclassesToSearch) => - rule + rule.asInstanceOf[Rule[NodeT, MatchT, EGraphT]] } }.toMap From 0d196ff916d4d6834c9babdd3d248726dbfb521b Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 13:53:36 -0500 Subject: [PATCH 04/33] Refactor command searcher construction logic --- .../eqsat/rewriting/EClassSearcher.scala | 42 ++++++++++++++++++- .../eqsat/saturation/SearchAndApply.scala | 30 +------------ 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala index 202b659f..84cb7bc6 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala @@ -3,6 +3,7 @@ package foresight.eqsat.rewriting import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.EClassCall import foresight.eqsat.readonly.EGraph +import foresight.eqsat.commands.{Command, CommandQueue} import foresight.util.collections.StrictMapOps.toStrictMapOps import java.util.concurrent.atomic.AtomicIntegerArray @@ -156,7 +157,46 @@ private[eqsat] object EClassSearcher { ]( rulesPerSharedEClassToSearch: Map[EClassesToSearch[EGraphT], Seq[Rule[NodeT, MatchT, EGraphT]]], regularRules: Seq[Rewrite[NodeT, MatchT, EGraphT]] - ) + ) { + + /** + * Creates command searchers for a group of rules that share the same EClassesToSearch. + * + * Each searcher first searches for matches using the shared EClassesToSearch, + * then applies all associated rules to produce commands. These commands + * are passed to the provided continuation. + * + * @param eclassesToSearch The EClassesToSearch instance to get rules for. + * @param continuation Continuation to handle commands produced by rule applications. + * @return A sequence of EClassSearcher instances that produce commands. + */ + def commandSearchers(eclassesToSearch: EClassesToSearch[EGraphT], + continuation: Command[NodeT] => Unit): Seq[EClassSearcher[NodeT, MatchT, EGraphT]] = { + rulesPerSharedEClassToSearch(eclassesToSearch).map { + case Rule(_, searcher: EClassSearcher[NodeT, MatchT, _], applier) => + val castSearcher = searcher.asInstanceOf[EClassSearcher[NodeT, MatchT, EGraphT]] + + castSearcher + .andThen(new castSearcher.ContinuationBuilder { + def apply(downstream: castSearcher.Continuation): castSearcher.Continuation = (m: MatchT, egraph: EGraphT) => { + if (downstream(m, egraph)) { + applier(m, egraph) match { + case CommandQueue(Seq()) => // Ignore no-op commands. + case cmd => + // Collect nontrivial commands. + continuation(cmd) + } + true + } else { + false + } + } + }) + + case _ => throw new IllegalStateException("Expected EClassSearcher rule.") + } + } + } /** * Partitions a sequence of rules into those that share EClassesToSearch instances and those that do not. diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index 6b2a6105..b3ae23d9 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -289,35 +289,9 @@ object SearchAndApply { } // Process shared EClassesToSearch rules together. - for ((eclassesToSearch, sharedRules) <- partitioned.rulesPerSharedEClassToSearch) { + for (eclassesToSearch <- partitioned.rulesPerSharedEClassToSearch.keys) { updates ++= ruleMatchingAndApplicationParallelize.collectFrom[Command[NodeT]] { (add: Command[NodeT] => Unit) => - // Build combined searchers that first search and for each match apply the corresponding rule's applier. - // Each combined searcher corresponds to one of the shared rules. - val commandSearchers = sharedRules.map { - case Rule(_, searcher: EClassSearcher[NodeT, MatchT, _], applier) => - val castSearcher = searcher.asInstanceOf[EClassSearcher[NodeT, MatchT, EGraphT]] - - castSearcher - .andThen(new castSearcher.ContinuationBuilder { - def apply(downstream: castSearcher.Continuation): Continuation[NodeT, MatchT, EGraphT] = (m: MatchT, egraph: EGraphT) => { - if (downstream(m, egraph)) { - applier(m, egraph) match { - case CommandQueue(Seq()) => // Ignore no-op commands. - case cmd => - // Collect nontrivial commands. - add(cmd) - } - true - } else { - false - } - } - }) - - case _ => - throw new IllegalStateException("Expected only EClassSearcher rules in shared EClassesToSearch group.") - } - + val commandSearchers = partitioned.commandSearchers(eclassesToSearch, add) EClassSearcher.searchMultiple( UnsafeSeqFromArray(commandSearchers.toArray), eclassesToSearch(egraph), From 9079bf4bc99f79e6da390a074e4ff04c8d46a9b6 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 14:17:30 -0500 Subject: [PATCH 05/33] Introduce SimplifiedAddCommandInstantiator to reduce closure allocations --- .../rewriting/patterns/PatternApplier.scala | 62 ++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index c6bdc275..730b0195 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -67,39 +67,43 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } - private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], - m: PatternMatch[NodeT], - egraph: EGraphT, - builder: CommandQueueBuilder[NodeT]): EClassSymbol = { + private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], + egraph: EGraphT, + builder: CommandQueueBuilder[NodeT]) { + def instantiate(pattern: MixedTree[NodeT, Pattern.Var]): EClassSymbol = { + pattern match { + case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) + case MixedTree.Node(t, defs@Seq(), uses, args) => + // No definitions, so we can reuse the PatternMatch and its original slot mapping + addSimplifiedNode(t, defs, uses, args) - pattern match { - case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) - case MixedTree.Node(t, defs@Seq(), uses, args) => - // No definitions, so we can reuse the PatternMatch and its original slot mapping - addSimplifiedNode(m, t, defs, uses, args, egraph, builder) - - case MixedTree.Node(t, defs, uses, args) => - val defSlots = defs.map { (s: Slot) => - m.slotMapping.get(s) match { - case Some(v) => v - case None => Slot.fresh() + case MixedTree.Node(t, defs, uses, args) => + val defSlots = defs.map { (s: Slot) => + m.slotMapping.get(s) match { + case Some(v) => v + case None => Slot.fresh() + } } - } - val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - addSimplifiedNode(newMatch, t, defSlots, uses, args, egraph, builder) + val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) + new SimplifiedAddCommandInstantiator(newMatch, egraph, builder).addSimplifiedNode(t, defSlots, uses, args) + } + } + + private def addSimplifiedNode(nodeType: NodeT, + definitions: SlotSeq, + uses: SlotSeq, + args: immutable.ArraySeq[MixedTree[NodeT, Pattern.Var]]): EClassSymbol = { + val argSymbols = CommandQueueBuilder.symbolArrayFrom(args, instantiate) + val useSymbols = uses.map(m.apply: Slot => Slot) + builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, egraph) } } - private def addSimplifiedNode(m: PatternMatch[NodeT], - nodeType: NodeT, - definitions: SlotSeq, - uses: SlotSeq, - args: immutable.ArraySeq[MixedTree[NodeT, Pattern.Var]], - egraph: EGraphT, - builder: CommandQueueBuilder[NodeT]): EClassSymbol = { - val argSymbols = CommandQueueBuilder.symbolArrayFrom( - args, (mt: MixedTree[NodeT, Pattern.Var]) => instantiateAsSimplifiedAddCommand(mt, m, egraph, builder)) - val useSymbols = uses.map(m.apply: Slot => Slot) - builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, egraph) + private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], + m: PatternMatch[NodeT], + egraph: EGraphT, + builder: CommandQueueBuilder[NodeT]): EClassSymbol = { + + new SimplifiedAddCommandInstantiator(m, egraph, builder).instantiate(pattern) } } From 11319157d6e48aeb69ff476c63ab3e254ac63b45 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 14:39:02 -0500 Subject: [PATCH 06/33] Add CommandQueue.highestDependencyForNodes method to efficiently compute batch index from node dependencies --- .../eqsat/commands/CommandQueue.scala | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala index 8fcb232a..786995c8 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala @@ -230,6 +230,35 @@ object CommandQueue { } } + /** + * Computes the highest batch index among the dependencies of the given nodes, + * based on the currently known definitions in `defs`. + */ + private def highestDependencyForNodes[NodeT]( + nodes: ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])], + defs: mutable.HashMap[EClassSymbol, Int] + ): Int = { + var highest = -1 + var n = 0 + while (n < nodes.length) { + val args = nodes(n)._2.args + var a = 0 + while (a < args.length) { + args(a) match { + case use: EClassSymbol.Virtual => + if (defs.contains(use)) { + val idx = defs(use) + if (idx > highest) highest = idx + } + case _ => // ignore non-virtual uses + } + a += 1 + } + n += 1 + } + highest + } + /** * Batches independent [[AddManyCommand]]s into layers based on their intra-batch dependencies. * @@ -248,9 +277,7 @@ object CommandQueue { val defs = mutable.HashMap.empty[EClassSymbol, Int] for (command <- group) { - val highestDependency = (-1 +: command.uses.collect { - case use: EClassSymbol.Virtual if defs.contains(use) => defs(use) - }).max + val highestDependency = highestDependencyForNodes[NodeT](command.nodes, defs) val batchIndex = highestDependency + 1 if (batchIndex == batches.size) { From f2863447d456653381a5d7588504f3c9f92332df Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 15:02:22 -0500 Subject: [PATCH 07/33] Temporarily disable search-loop interchange optimization in SearchAndApply --- .../foresight/eqsat/saturation/SearchAndApply.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index b3ae23d9..ab127d3c 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -171,6 +171,8 @@ object SearchAndApply { MatchT ]: SearchAndApply[NodeT, Rewrite[NodeT, MatchT, EGraphT], EGraphT, MatchT] = { new NoMatchCaching[NodeT, EGraphT, MatchT] { + override def searchLoopInterchange: Boolean = false + override def update(command: Command[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, @@ -194,6 +196,8 @@ object SearchAndApply { MatchT ]: SearchAndApply[NodeT, Rewrite[NodeT, MatchT, EGraphT], EGraphT, MatchT] = { new NoMatchCaching[NodeT, EGraphT, MatchT] { + override def searchLoopInterchange: Boolean = false + override def update(command: Command[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, @@ -249,6 +253,13 @@ object SearchAndApply { EGraphT <: readonly.EGraph[NodeT], MatchT ] extends SearchAndApply[NodeT, Rewrite[NodeT, MatchT, EGraphT], EGraphT, MatchT] { + /** + * Whether to perform the search-loop interchange optimization, which groups rules that search the same + * e-classes together to reduce redundant work. + * @return True to enable search-loop interchange, false to disable it. + */ + def searchLoopInterchange: Boolean + final override def search(rule: Rewrite[NodeT, MatchT, EGraphT], egraph: EGraphT, parallelize: ParallelMap): Seq[MatchT] = { @@ -269,7 +280,7 @@ object SearchAndApply { val updates = Seq.newBuilder[Command[NodeT]] val ruleMatchingAndApplicationParallelize = parallelize.child("rule matching+application") - if (egraph.classCount <= EClassSearcher.smallEGraphThreshold) { + if (!searchLoopInterchange || egraph.classCount <= EClassSearcher.smallEGraphThreshold) { // Small e-graph optimization: for small e-graphs, the overhead of partitioning and // fusing rule applications outweighs the benefits. Just process each rule normally. for (rule <- rules) { From c1e0661815154e8c813465c388ba1cf34abb349d Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 15:02:52 -0500 Subject: [PATCH 08/33] Optimize CommandQueue.optimizeAdds --- .../eqsat/commands/CommandQueue.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala index 786995c8..9a3d9c94 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala @@ -5,6 +5,7 @@ import foresight.eqsat.{EClassCall, EClassSymbol, ENodeSymbol, MixedTree} import foresight.eqsat.mutable.{EGraph => MutableEGraph} import foresight.eqsat.readonly +import java.util import scala.collection.mutable import scala.collection.compat.immutable.ArraySeq @@ -230,13 +231,10 @@ object CommandQueue { } } - /** - * Computes the highest batch index among the dependencies of the given nodes, - * based on the currently known definitions in `defs`. - */ - private def highestDependencyForNodes[NodeT]( + + @inline private def highestDependencyForNodes[NodeT]( nodes: ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])], - defs: mutable.HashMap[EClassSymbol, Int] + defs: util.IdentityHashMap[EClassSymbol, Int] ): Int = { var highest = -1 var n = 0 @@ -246,8 +244,8 @@ object CommandQueue { while (a < args.length) { args(a) match { case use: EClassSymbol.Virtual => - if (defs.contains(use)) { - val idx = defs(use) + if (defs.containsKey(use)) { + val idx = defs.get(use) if (idx > highest) highest = idx } case _ => // ignore non-virtual uses @@ -274,8 +272,15 @@ object CommandQueue { // i is the highest batch in which any of its dependencies are defined. type ArraySeqBuilder = mutable.Builder[(EClassSymbol.Virtual, ENodeSymbol[NodeT]), ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])]] val batches = mutable.ArrayBuffer.empty[ArraySeqBuilder] - val defs = mutable.HashMap.empty[EClassSymbol, Int] + + // Pre-size to avoid rehashing during hot insert loop. + var totalDefs = 0 + for (command <- group) { + totalDefs += command.nodes.length + } + + val defs = new util.IdentityHashMap[EClassSymbol, Int](totalDefs) for (command <- group) { val highestDependency = highestDependencyForNodes[NodeT](command.nodes, defs) @@ -292,7 +297,7 @@ object CommandQueue { var i = 0 while (i < command.nodes.length) { val node = command.nodes(i) - defs(node._1) = batchIndex + defs.put(node._1, batchIndex) i += 1 } } From 9e1bc7ff8453cd796be430522a80b948bea3d3f7 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 15:28:35 -0500 Subject: [PATCH 09/33] Eliminate legacy Command simplification logic --- .../eqsat/commands/AddManyCommand.scala | 68 ------------------- .../foresight/eqsat/commands/Command.scala | 35 ---------- .../eqsat/commands/CommandQueue.scala | 31 --------- .../eqsat/commands/UnionManyCommand.scala | 40 ----------- .../foresight/eqsat/rewriting/Applier.scala | 25 ------- 5 files changed, 199 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala b/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala index ff605c21..d7f4f107 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala @@ -77,72 +77,4 @@ final case class AddManyCommand[NodeT]( anyChanges } - - /** - * Simplifies this command for the given e-graph and partial reification. - * - * For each node: - * - All argument symbols are refined using the partial reification. - * - If all arguments are real, the node is fully reified and checked - * against the e-graph: - * - If already present, the output symbol is bound immediately. - * - If absent, the node remains for insertion. - * - * @param egraph Target e-graph for context. - * @param partialReification Known bindings for virtual symbols. - * @return - * - A simplified command containing only unresolved insertions, - * or [[CommandQueue.empty]] if all nodes were already present. - * - An updated partial reification containing all newly resolved outputs. - */ - override def simplify( - egraph: readonly.EGraph[NodeT], - partialReification: Map[EClassSymbol.Virtual, EClassCall] - ): (Command[NodeT], Map[EClassSymbol.Virtual, EClassCall]) = { - - val resolvedBuilder = Map.newBuilder[EClassSymbol.Virtual, EClassCall] - val unresolvedBuilder = ArraySeq.newBuilder[(EClassSymbol.Virtual, ENodeSymbol[NodeT])] - - def resolveAllOrNull(args: Seq[EClassSymbol]): Seq[EClassCall] = { - val resolvedArgs = Seq.newBuilder[EClassCall] - for (arg <- args) { - arg match { - case call: EClassCall => - resolvedArgs += call - case v: EClassSymbol.Virtual if partialReification.contains(v) => - resolvedArgs += partialReification(v) - case _ => - // Argument is virtual and not in the partial reification. - // Cannot fully resolve this node. - return null - } - } - resolvedArgs.result() - } - - for ((result, node) <- nodes) { - val resolvedArgs = resolveAllOrNull(node.args) - if (resolvedArgs != null) { - val reifiedNode = ENode(node.nodeType, node.definitions, node.uses, resolvedArgs) - egraph.find(reifiedNode) match { - case Some(call) => - resolvedBuilder += (result -> call) - case None => - val refined = node.withArgs(UnsafeSeqFromArray(resolvedArgs)) - unresolvedBuilder += (result -> refined) - } - } else { - val refined = node.withArgs(node.args.map(_.refine(partialReification))) - unresolvedBuilder += (result -> refined) - } - } - - val resolved = resolvedBuilder.result() - val unresolved = unresolvedBuilder.result() - - ( - if (unresolved.isEmpty) CommandQueue.empty else AddManyCommand(unresolved), - resolved - ) - } } diff --git a/foresight/src/main/scala/foresight/eqsat/commands/Command.scala b/foresight/src/main/scala/foresight/eqsat/commands/Command.scala index 251faa18..3ca917c3 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/Command.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/Command.scala @@ -87,41 +87,6 @@ trait Command[NodeT] { (None, mutableReification.toMap) } } - - /** - * Returns a semantically equivalent command that is cheaper to execute on the given e-graph. - * - * Typical simplifications include: - * - eliminating unions whose endpoints are already congruent - * - dropping inserts of nodes already present - * - narrowing work by pre-binding outputs in the returned partial reification - * - * The returned partial map contains bindings this command can prove without running, which callers - * may compose across multiple commands to reduce future work. - * - * @param egraph Target e-graph used as context for optimization. - * @param partialReification Known virtual-to-concrete bindings available upstream. - * @return A pair `(simplifiedCommand, partialBindings)` for this command. - */ - def simplify( - egraph: EGraph[NodeT], - partialReification: Map[EClassSymbol.Virtual, EClassCall] - ): (Command[NodeT], Map[EClassSymbol.Virtual, EClassCall]) - - /** - * Returns a semantically equivalent command that is cheaper to execute on the given e-graph. - * Assumes no prior reification information and does not return any partial bindings. - * - * Typical simplifications include: - * - eliminating unions whose endpoints are already congruent - * - dropping inserts of nodes already present - * - narrowing work by pre-binding outputs in the returned partial reification - * - * @param egraph Target e-graph used as context for optimization. - * @return A simplified command. - */ - final def simplify(egraph: EGraph[NodeT]): Command[NodeT] = - simplify(egraph, Map.empty)._1 } /** diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala index 9a3d9c94..40bcc1e3 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala @@ -146,36 +146,6 @@ final case class CommandQueue[NodeT](commands: Seq[Command[NodeT]]) extends Comm case queue: CommandQueue[NodeT] => queue.flatCommands case command => Seq(command) } - - /** - * Simplifies each command against the current e-graph and threads partial reification. - * - * Each command is simplified in sequence, accumulating any discovered bindings. The resulting queue - * is then [[optimized]] to merge unions and batch adds. - * - * @param egraph Context graph used by sub-command simplifications. - * @param partialReification Upstream bindings available prior to running this queue. - * @return The simplified-and-optimized queue and the accumulated partial bindings. - * - * @example - * {{{ - * val (simp, partial) = q.simplify(g, Map.empty) - * val (maybeG, finalRefs) = simp.apply(g, partial, parallel) - * }}} - */ - override def simplify( - egraph: readonly.EGraph[NodeT], - partialReification: Map[EClassSymbol.Virtual, EClassCall] - ): (Command[NodeT], Map[EClassSymbol.Virtual, EClassCall]) = { - val newQueue = Seq.newBuilder[Command[NodeT]] - var newReification = partialReification - for (command <- flatCommands) { - val (simplified, newReificationPart) = command.simplify(egraph, newReification) - newQueue += simplified - newReification ++= newReificationPart - } - (CommandQueue(newQueue.result()), newReification) - } } /** @@ -273,7 +243,6 @@ object CommandQueue { type ArraySeqBuilder = mutable.Builder[(EClassSymbol.Virtual, ENodeSymbol[NodeT]), ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])]] val batches = mutable.ArrayBuffer.empty[ArraySeqBuilder] - // Pre-size to avoid rehashing during hot insert loop. var totalDefs = 0 for (command <- group) { diff --git a/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala b/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala index c7ec6970..58bc42ab 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala @@ -67,44 +67,4 @@ final case class UnionManyCommand[NodeT](pairs: Seq[(EClassSymbol, EClassSymbol) true } } - - /** - * Simplifies the union set against the current e-graph and partial bindings. - * - * Each side is first refined using `partialReification`. Pairs that become - * two real calls already known to be equal are dropped. If all pairs drop, - * the result is [[CommandQueue.empty]]; otherwise a reduced [[UnionManyCommand]] - * is returned. - * - * @param egraph Context used for equality checks. - * @param partialReification Known virtual-to-real bindings. - * @return A simplified command and an (empty) partial reification. - * - * @example - * {{{ - * val v = EClassSymbol.virtual() - * val simplified = UnionManyCommand(Seq(v -> EClassSymbol.real(callX))) - * .simplify(egraph, Map(v -> callX)) - * // Becomes CommandQueue.empty because both sides resolve to the same class. - * }}} - */ - override def simplify( - egraph: readonly.EGraph[NodeT], - partialReification: Map[EClassSymbol.Virtual, EClassCall] - ): (Command[NodeT], Map[EClassSymbol.Virtual, EClassCall]) = { - val builder = Seq.newBuilder[(EClassSymbol, EClassSymbol)] - for ((left, right) <- pairs) { - val lRefined = left.refine(partialReification) - val rRefined = right.refine(partialReification) - (lRefined, rRefined) match { - case (l: EClassCall, r: EClassCall) => - if (!egraph.areSame(l, r)) builder += ((lRefined, rRefined)) - case _ => - builder += ((lRefined, rRefined)) - } - } - val simplifiedPairs = builder.result() - if (simplifiedPairs.isEmpty) (CommandQueue.empty, Map.empty) - else (UnionManyCommand(simplifiedPairs), Map.empty) - } } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala index 78d657fc..4e67eea1 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala @@ -59,24 +59,6 @@ object Applier { override def tryReverse: Option[Searcher[NodeT, MatchT, EGraphT]] = Some(Searcher.empty) } - /** - * An applier that simplifies the commands produced by another applier. - * @param applier Inner applier whose commands will be simplified. - * @tparam NodeT Node payload type stored in the e-graph. - * @tparam MatchT The match type produced by a [[Searcher]] and consumed here. - * @tparam EGraphT Concrete e-graph type (must be both [[EGraphLike]] and [[EGraph]]). - */ - final case class Simplify[ - NodeT, - MatchT, - EGraphT <: EGraph[NodeT] - ](applier: Applier[NodeT, MatchT, EGraphT]) extends Applier[NodeT, MatchT, EGraphT] { - override def apply(m: MatchT, egraph: EGraphT): Command[NodeT] = { - val command = applier.apply(m, egraph) - command.simplify(egraph) - } - } - /** * Conditionally apply: run `applier` only when `filter(match, egraph)` is true; otherwise emit no-op. * @@ -172,13 +154,6 @@ object Applier { EGraphT <: EGraph[NodeT] ](private val applier: Applier[NodeT, MatchT, EGraphT]) extends AnyVal { - /** - * Simplify the commands produced by this applier before returning them. - * - * @return An applier that simplifies its commands. - */ - def simplify: Applier[NodeT, MatchT, EGraphT] = Simplify(applier) - /** * Conditionally apply this applier; otherwise emit an empty command. * From 286df58e7c6e0d9376b511cf5abd82057b093b40 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 16:05:25 -0500 Subject: [PATCH 10/33] Restore parallelization across rules --- .../eqsat/saturation/SearchAndApply.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index ab127d3c..794a6906 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -276,17 +276,21 @@ object SearchAndApply { final override def apply(rules: Seq[Rewrite[NodeT, MatchT, EGraphT]], egraph: EGraphT, parallelize: ParallelMap): Option[EGraphT] = { - - val updates = Seq.newBuilder[Command[NodeT]] val ruleMatchingAndApplicationParallelize = parallelize.child("rule matching+application") if (!searchLoopInterchange || egraph.classCount <= EClassSearcher.smallEGraphThreshold) { // Small e-graph optimization: for small e-graphs, the overhead of partitioning and // fusing rule applications outweighs the benefits. Just process each rule normally. - for (rule <- rules) { - updates += rule.delayed(egraph, ruleMatchingAndApplicationParallelize) - } + val updates = ruleMatchingAndApplicationParallelize( + rules, + (rule: Rewrite[NodeT, MatchT, EGraphT]) => { + rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + } + ).toSeq + update(updates, Map.empty[String, Seq[MatchT]], egraph, parallelize) } else { + val updates = Seq.newBuilder[Command[NodeT]] + // Idea: EClassSearcher rules are the common case, and they apply in parallel over a subset of // e-classes in the e-graph. If multiple rules share the same subset of e-classes to search, // we can group them together to fuse iterations over those e-classes. Fusion both reduces @@ -311,9 +315,9 @@ object SearchAndApply { ) } } - } - update(updates.result(), Map.empty[String, Seq[MatchT]], egraph, parallelize) + update(updates.result(), Map.empty[String, Seq[MatchT]], egraph, parallelize) + } } } } From 60b84b3f30d5c99bd9e909b1a84f5fd9266bc4f8 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:09:56 -0500 Subject: [PATCH 11/33] Replace Command class hierarchy with CommandSchedule --- .../examples/arithWithLang/ApplierOps.scala | 8 +- .../eqsat/examples/arith/ApplierOps.scala | 10 +- .../eqsat/examples/liar/ApplierOps.scala | 18 +- .../examples/liar/TypeRequirements.scala | 11 +- .../eqsat/examples/sdql/ApplierOps.scala | 10 +- .../scala/foresight/eqsat/EClassCall.scala | 2 +- .../main/scala/foresight/eqsat/ENode.scala | 6 +- .../eqsat/commands/AddManyCommand.scala | 80 ---- .../foresight/eqsat/commands/Command.scala | 173 -------- .../eqsat/commands/CommandQueue.scala | 376 ---------------- .../eqsat/commands/CommandQueueBuilder.scala | 242 ---------- .../eqsat/commands/CommandSchedule.scala | 137 ++++++ .../commands/CommandScheduleBuilder.scala | 200 +++++++++ .../ConcurrentCommandScheduleBuilder.scala | 77 ++++ .../eqsat/commands/UnionManyCommand.scala | 70 --- .../foresight/eqsat/commands/package.scala | 83 ---- .../AbstractMutableHashConsEGraph.scala | 9 +- .../hashCons/immutable/HashConsEGraph.scala | 11 +- .../eqsat/immutable/EGraphLike.scala | 6 +- .../eqsat/immutable/EGraphWithMetadata.scala | 8 +- .../EGraphWithRecordedApplications.scala | 8 +- .../eqsat/immutable/EGraphWithRoot.scala | 8 +- .../foresight/eqsat/mutable/EGraph.scala | 6 +- .../eqsat/mutable/EGraphWithMetadata.scala | 3 +- .../mutable/UpdatingImmutableEGraph.scala | 4 +- .../foresight/eqsat/rewriting/Applier.scala | 27 +- .../eqsat/rewriting/EClassSearcher.scala | 24 +- .../foresight/eqsat/rewriting/Rewrite.scala | 55 ++- .../foresight/eqsat/rewriting/Rule.scala | 38 +- .../foresight/eqsat/rewriting/Searcher.scala | 18 - .../eqsat/rewriting/SearcherLike.scala | 23 + .../foresight/eqsat/rewriting/package.scala | 4 +- .../rewriting/patterns/PatternApplier.scala | 33 +- .../eqsat/saturation/SearchAndApply.scala | 118 +++-- .../commands/CommandQueueBuilderTest.scala | 414 +++++++++--------- 35 files changed, 872 insertions(+), 1448 deletions(-) delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/Command.scala delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/CommandQueueBuilder.scala create mode 100644 foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala create mode 100644 foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala create mode 100644 foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala delete mode 100644 foresight/src/main/scala/foresight/eqsat/commands/package.scala diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala index fa299026..bc1d6b80 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala @@ -1,10 +1,10 @@ package foresight.eqsat.examples.arithWithLang import foresight.eqsat.{EClassCall, MixedTree, Slot} -import foresight.eqsat.commands.Command +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.extraction.ExtractionAnalysis import foresight.eqsat.lang.Language -import foresight.eqsat.immutable.{EGraphLike, EGraphWithMetadata, EGraph} +import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.PatternMatch @@ -26,7 +26,7 @@ object ApplierOps { destination: PatternVar): Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]] = { new Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]] { - override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT]): Command[ArithIR] = { + override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT], builder: CommandScheduleBuilder[ArithIR]): Unit = { val extractedTree = ExtractionAnalysis.smallest[ArithIR].extractor(m(source.variable), egraph) val extractedExpr = L.fromTree[EClassCall](extractedTree) @@ -47,7 +47,7 @@ object ApplierOps { val substituted = subst(extractedExpr) val newMatch = m.copy(varMapping = m.varMapping + (destination.variable -> L.toTree[EClassCall](substituted))) - applier.apply(newMatch, egraph) + applier.apply(newMatch, egraph, builder) } } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala index d308afcd..375b7de1 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala @@ -1,11 +1,11 @@ package foresight.eqsat.examples.arith -import foresight.eqsat.commands.Command +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.extraction.ExtractionAnalysis -import foresight.eqsat.readonly.{EGraphWithMetadata, EGraph} +import foresight.eqsat.readonly.{EGraph, EGraphWithMetadata} import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} -import foresight.eqsat._ +import foresight.eqsat.* object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraph[ArithIR]](private val applier: Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]]) extends AnyVal { @@ -24,7 +24,7 @@ object ApplierOps { destination: Pattern.Var): Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]] = { new Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]] { - override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT]): Command[ArithIR] = { + override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT], builder: CommandScheduleBuilder[ArithIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArithIR].extractor[EGraphT](m(source), egraph) def subst(tree: Tree[ArithIR]): MixedTree[ArithIR, EClassCall] = { @@ -37,7 +37,7 @@ object ApplierOps { val substituted = subst(extracted) val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) - applier.apply(newMatch, egraph) + applier.apply(newMatch, egraph, builder) } } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala index 07eeb3bd..e0a40ed4 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala @@ -1,12 +1,12 @@ package foresight.eqsat.examples.liar -import foresight.eqsat.commands.Command +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.extraction.ExtractionAnalysis -import foresight.eqsat.immutable.{EGraphLike, EGraphWithMetadata, EGraph} +import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternApplier, PatternMatch} -import foresight.eqsat._ +import foresight.eqsat.* object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraphLike[ArrayIR, EGraphT] with EGraph[ArrayIR]](private val applier: Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]]) extends AnyVal { @@ -25,7 +25,7 @@ object ApplierOps { destination: Pattern.Var): Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]] = { new Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]] { - override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT]): Command[ArrayIR] = { + override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT], builder: CommandScheduleBuilder[ArrayIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArrayIR].extractor[EGraphT](m(source), egraph) def typeOf(tree: MixedTree[ArrayIR, EClassCall]): MixedTree[Type, EClassCall] = { @@ -45,7 +45,7 @@ object ApplierOps { val substituted = subst(extracted) val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) - applier.apply(newMatch, egraph) + applier.apply(newMatch, egraph, builder) } } } @@ -62,10 +62,12 @@ object ApplierOps { */ def typeChecked: Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]] = { new Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]] { - override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT]): Command[ArrayIR] = { + override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT], builder: CommandScheduleBuilder[ArrayIR]): Unit = { val tree = applier.instantiate(m) - inferType(tree.mapAtoms(_.asInstanceOf[EClassCall]), egraph) - Command.equivalenceSimplified(EClassSymbol.real(m.root), tree, egraph) + val realTree = tree.mapAtoms(_.asInstanceOf[EClassCall]) + inferType(realTree, egraph) + val c = builder.addSimplifiedReal(realTree, egraph) + builder.unionSimplified(EClassSymbol.real(m.root), c, egraph) } } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/TypeRequirements.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/TypeRequirements.scala index c03d4a92..fce48288 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/TypeRequirements.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/TypeRequirements.scala @@ -1,10 +1,9 @@ package foresight.eqsat.examples.liar -import foresight.eqsat.commands.Command -import foresight.eqsat.parallel.ParallelMap +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.rewriting.SearcherContinuation.Continuation import foresight.eqsat.rewriting.patterns.{CompiledPattern, Pattern, PatternMatch} -import foresight.eqsat.rewriting.{Applier, ReversibleApplier, ReversibleSearcher, Searcher, SearcherContinuation} +import foresight.eqsat.rewriting.{Applier, ReversibleApplier, Searcher, SearcherContinuation} import foresight.eqsat.MixedTree import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} @@ -54,7 +53,7 @@ object TypeRequirements { } final case class ApplierWithRequirements[EGraphT <: EGraphLike[ArrayIR, EGraphT] with EGraph[ArrayIR]](applier: Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]], - types: Map[Pattern.Var, MixedTree[ArrayIR, Pattern.Var]]) + types: Map[Pattern.Var, MixedTree[ArrayIR, Pattern.Var]]) extends ReversibleApplier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]] { // Precompile the patterns. @@ -62,8 +61,8 @@ object TypeRequirements { case (v, t) => v -> t.compiled[EGraph[ArrayIR]] } - override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT]): Command[ArrayIR] = { - applier.flatMap((m2: PatternMatch[ArrayIR], egraph2: EGraphWithMetadata[ArrayIR, EGraphT]) => checkMatch(m2, compiledPatterns, egraph2)).apply(m, egraph) + override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT], builder: CommandScheduleBuilder[ArrayIR]): Unit = { + applier.flatMap((m2: PatternMatch[ArrayIR], egraph2: EGraphWithMetadata[ArrayIR, EGraphT]) => checkMatch(m2, compiledPatterns, egraph2)).apply(m, egraph, builder) } override def tryReverse: Option[Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]]] = { diff --git a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala index ca7649d5..b0d8d235 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala @@ -1,11 +1,11 @@ package foresight.eqsat.examples.sdql -import foresight.eqsat.commands.Command +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.extraction.ExtractionAnalysis import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} -import foresight.eqsat.immutable.{EGraphLike, EGraphWithMetadata, EGraph} -import foresight.eqsat._ +import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} +import foresight.eqsat.* object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraphLike[SdqlIR, EGraphT] with EGraph[SdqlIR]](private val applier: Applier[SdqlIR, PatternMatch[SdqlIR], EGraphWithMetadata[SdqlIR, EGraphT]]) extends AnyVal { @@ -24,7 +24,7 @@ object ApplierOps { destination: Pattern.Var): Applier[SdqlIR, PatternMatch[SdqlIR], EGraphWithMetadata[SdqlIR, EGraphT]] = { new Applier[SdqlIR, PatternMatch[SdqlIR], EGraphWithMetadata[SdqlIR, EGraphT]] { - override def apply(m: PatternMatch[SdqlIR], egraph: EGraphWithMetadata[SdqlIR, EGraphT]): Command[SdqlIR] = { + override def apply(m: PatternMatch[SdqlIR], egraph: EGraphWithMetadata[SdqlIR, EGraphT], builder: CommandScheduleBuilder[SdqlIR]): Unit = { val extracted = ExtractionAnalysis.smallest[SdqlIR].extractor[EGraphT](m(source), egraph) def subst(tree: Tree[SdqlIR]): MixedTree[SdqlIR, EClassCall] = { @@ -37,7 +37,7 @@ object ApplierOps { val substituted = subst(extracted) val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) - applier.apply(newMatch, egraph) + applier.apply(newMatch, egraph, builder) } } } diff --git a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala index 87caee5a..d6ee2399 100644 --- a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala +++ b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala @@ -140,7 +140,7 @@ sealed trait EClassSymbol { * val realCall = v.reify(Map(v -> call)) // returns call * }}} */ - final def reify(reification: collection.Map[EClassSymbol.Virtual, EClassCall]): EClassCall = this match { + final def reify(reification: EClassSymbol.Virtual => EClassCall): EClassCall = this match { case call: EClassCall => call case virtual: EClassSymbol.Virtual => reification(virtual) } diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index 3b6bb254..fec89df2 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -340,7 +340,7 @@ final class ENode[+NodeT] private ( result } - override def reify(reification: collection.Map[EClassSymbol.Virtual, EClassCall]): ENode[NodeT] = this + override def reify(reification: EClassSymbol.Virtual => EClassCall): ENode[NodeT] = this } /** @@ -504,7 +504,7 @@ sealed trait ENodeSymbol[+NodeT] { * @param reification Mapping from virtual e-class symbols to concrete [[EClassCall]]s. * @return An [[ENode]] with all arguments fully resolved. */ - def reify(reification: collection.Map[EClassSymbol.Virtual, EClassCall]): ENode[NodeT] + def reify(reification: EClassSymbol.Virtual => EClassCall): ENode[NodeT] /** * Creates a copy of this symbol with the given arguments. @@ -583,7 +583,7 @@ object ENodeSymbol { * val node: ENode[MyOp] = symbol.reify(Map(v1 -> call1)) * }}} */ - override def reify(reification: collection.Map[EClassSymbol.Virtual, EClassCall]): ENode[NodeT] = { + override def reify(reification: EClassSymbol.Virtual => EClassCall): ENode[NodeT] = { val reifiedArgs = args.map(_.reify(reification)) ENode(nodeType, definitions, uses, reifiedArgs) } diff --git a/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala b/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala deleted file mode 100644 index d7f4f107..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/AddManyCommand.scala +++ /dev/null @@ -1,80 +0,0 @@ -package foresight.eqsat.commands - -import foresight.eqsat.{AddNodeResult, EClassCall, EClassSymbol, ENodeSymbol, ENode} -import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.mutable -import foresight.eqsat.readonly -import foresight.util.collections.UnsafeSeqFromArray - -import scala.collection.compat.immutable.ArraySeq -import scala.collection.mutable.{Map => MutableMap} - -/** - * A [[Command]] that inserts multiple [[ENodeSymbol]]s into an e-graph in one batch. - * - * Each entry in [[nodes]] associates: - * - a fresh [[EClassSymbol.Virtual]] that will represent the root e-class of the node once inserted, and - * - the [[ENodeSymbol]] itself, which may reference real or virtual arguments. - * - * Nodes in this batch are *independently* reifiable and can be processed in parallel. - * No reification state is shared between them during insertion. - * - * @tparam NodeT Node type for expressions represented by the e-graph. - * @param nodes Sequence of `(outputSymbol, node)` pairs to insert. The `outputSymbol` - * is bound to the resulting e-class after insertion. The node’s arguments - * are resolved via the reification map provided to [[apply]]. - */ -final case class AddManyCommand[NodeT]( - nodes: ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])] - ) extends Command[NodeT] { - - /** All argument symbols used by every node in this batch. */ - override def uses: Seq[EClassSymbol] = nodes.flatMap(_._2.args) - - /** All virtual symbols that will be defined as a result of this command. */ - override def definitions: Seq[EClassSymbol.Virtual] = nodes.map(_._1) - - /** - * Reifies all nodes in [[nodes]] using the given map, then inserts them - * into the e-graph with [[mutable.EGraph.tryAddMany]]. - * - * Nodes are processed independently, allowing parallel insertion. - * The returned reification map binds each output symbol from [[nodes]] - * to the [[EClassCall]] of its inserted or matched e-class. - * - * @param egraph Target e-graph to update. - * @param reification Mapping from virtual e-class symbols to concrete calls, used to resolve each node’s arguments - * before insertion. This map is mutated to include new bindings for every symbol in - * [[definitions]]. - * @param parallelize Strategy for distributing the work. - * @return `true` if at least one node was newly added, or `false` if all were already present. - * - * @example - * {{{ - * val v = EClassSymbol.virtual() - * val n = ENodeSymbol(nodeType, defs, uses, args = Seq(v)) - * val cmd = AddManyCommand(Seq(v -> n)) - * val updated = cmd(egraph, Map(v -> existingCall), parallel) - * }}} - */ - override def apply( - egraph: mutable.EGraph[NodeT], - reification: MutableMap[EClassSymbol.Virtual, EClassCall], - parallelize: ParallelMap - ): Boolean = { - val reifiedNodes = nodes.map(_._2.reify(reification)) - val addResults = egraph.tryAddMany(reifiedNodes, parallelize) - - var anyChanges: Boolean = false - for (((symbol, _), result) <- nodes.zip(addResults)) { - reification(symbol) = result.call - - result match { - case AddNodeResult.Added(_) => anyChanges = true - case AddNodeResult.AlreadyThere(_) => // no change - } - } - - anyChanges - } -} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/Command.scala b/foresight/src/main/scala/foresight/eqsat/commands/Command.scala deleted file mode 100644 index 3ca917c3..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/Command.scala +++ /dev/null @@ -1,173 +0,0 @@ -package foresight.eqsat.commands - -import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol, MixedTree} -import foresight.eqsat.immutable -import foresight.eqsat.mutable -import foresight.eqsat.readonly.EGraph - -import scala.collection.mutable.{Map => MutableMap} - -/** - * A [[Command]] encapsulates a single, replayable edit to an e-graph. - * - * Commands are pure values that describe what to do; they don’t perform any mutation until applied. - * This design allows callers to build, simplify, batch, or reorder edits before committing them to - * an e-graph. - * - * @tparam NodeT Node type for expressions represented by the e-graph. - */ -trait Command[NodeT] { - - /** - * The e-class symbols this command expects to already exist when it runs. - * - * These are the external dependencies needed to interpret the command (e.g., “union the new class - * with this existing one”). Implementations list all required symbols so that callers can: - * - validate a reification plan ahead of time - * - topologically sort commands - */ - def uses: Seq[EClassSymbol] - - /** - * The virtual e-class symbols this command promises to define when executed. - * - * Each virtual symbol represents an e-class that will be materialized in the target e-graph and - * receive a concrete [[EClassCall]]. After a successful [[apply]], the returned reification map - * contains an entry for every symbol listed here. - */ - def definitions: Seq[EClassSymbol.Virtual] - - /** - * Executes the command against an e-graph. - * - * @param egraph - * Destination e-graph that will be mutated in place. - * @param reification - * Mapping from virtual symbols to concrete calls available before this command runs. This is used - * to ground virtual references present in [[uses]] and is mutated to include new bindings for every - * symbol in [[definitions]]. - * @param parallelize - * Parallelization strategy to label and distribute any internal work. - * @return `true` if any change occurred, or `false` for a no-op. - */ - def apply(egraph: mutable.EGraph[NodeT], - reification: MutableMap[EClassSymbol.Virtual, EClassCall], - parallelize: ParallelMap): Boolean - - /** - * Executes the command against an e-graph. - * - * @param egraph - * Destination e-graph. Implementations may either return it unchanged or produce a new immutable - * e-graph snapshot. - * @param reification - * Mapping from virtual symbols to concrete calls available before this command runs. This is used - * to ground virtual references present in [[uses]]. - * @param parallelize - * Parallelization strategy to label and distribute any internal work. - * @return - A pair `(maybeNewGraph, outMap)` where: - * - `maybeNewGraph` is `Some(newGraph)` if any change occurred, or `None` for a no-op. - * - `outMap` binds every symbol in [[definitions]] to its realized [[EClassCall]]. - */ - final def applyImmutable[ - Repr <: immutable.EGraphLike[NodeT, Repr] with immutable.EGraph[NodeT] - ]( - egraph: Repr, - reification: Map[EClassSymbol.Virtual, EClassCall], - parallelize: ParallelMap - ): (Option[Repr], Map[EClassSymbol.Virtual, EClassCall]) = { - val mutableEGraph = mutable.FreezableEGraph[NodeT, Repr](egraph) - val mutableReification = MutableMap(reification.toSeq: _*) - val updated = apply(mutableEGraph, mutableReification, parallelize) - if (updated) { - (Some(mutableEGraph.freeze()), mutableReification.toMap) - } else { - (None, mutableReification.toMap) - } - } -} - -/** - * Factory methods for constructing common [[Command]] instances. - */ -object Command { - - /** - * Creates a [[Command]] that asserts an existing [[EClassSymbol]] is equivalent to a given - * expression tree. - * - * Internally, this: - * 1. Inserts the tree (creating a fresh virtual symbol for its root if needed) - * 2. Unions that root with `symbol` - * - * This is the canonical “add-and-unify” operation when you want to grow the e-graph with a - * concrete term and immediately equate it with an existing class. - * - * @param symbol E-class symbol to unify with the tree’s root. - * @param tree Expression to insert/reuse and equate. - * @tparam NodeT Node type for the expression. - * @return A compound command performing the insert and the union. - * - * @example - * {{{ - * import foresight.eqsat.commands.Command - * - * val cmd: Command[MyNode] = - * Command.equivalence(existingSym, myTree) - * - * val (optGraph, out) = - * cmd.simplify(egraph).apply(egraph, Map.empty, parallel) - * }}} - */ - def equivalence[NodeT]( - symbol: EClassSymbol, - tree: MixedTree[NodeT, EClassSymbol] - ): Command[NodeT] = { - val builder = new CommandQueueBuilder[NodeT] - val c = builder.add(tree) - builder.union(symbol, c) - builder.result() - } - - /** - * Creates a [[Command]] that asserts an existing [[EClassSymbol]] is equivalent to a given - * expression tree, simplifying the command with respect to the given e-graph. - * - * Internally, this: - * 1. Inserts the tree (creating a fresh virtual symbol for its root if needed) - * 2. Unions that root with `symbol` - * 3. Simplifies the resulting command with respect to `egraph` - * - * This is the canonical “add-and-unify” operation when you want to grow the e-graph with a - * concrete term and immediately equate it with an existing class, while avoiding redundant work. - * - * @param symbol E-class symbol to unify with the tree’s root. - * @param tree Expression to insert/reuse and equate. - * @param egraph E-graph used as context for simplification. - * @tparam NodeT Node type for the expression. - * @return A compound command performing the insert and the union, simplified with respect to `egraph`. - * - * @example - * {{{ - * import foresight.eqsat.commands.Command - * - * val cmd: Command[MyNode] = - * Command.equivalenceSimplified(existingSym, myTree, egraph) - * - * val (optGraph, out) = - * cmd.apply(egraph, Map.empty, parallel) - * }}} - */ - def equivalenceSimplified[NodeT]( - symbol: EClassSymbol, - tree: MixedTree[NodeT, EClassSymbol], - egraph: EGraph[NodeT] - ): Command[NodeT] = { - val builder = new CommandQueueBuilder[NodeT] - val c = builder.addSimplified(tree, egraph) - builder.unionSimplified(symbol, c, egraph) - builder.result() - } -} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala deleted file mode 100644 index 40bcc1e3..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueue.scala +++ /dev/null @@ -1,376 +0,0 @@ -package foresight.eqsat.commands - -import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol, ENodeSymbol, MixedTree} -import foresight.eqsat.mutable.{EGraph => MutableEGraph} -import foresight.eqsat.readonly - -import java.util -import scala.collection.mutable -import scala.collection.compat.immutable.ArraySeq - -/** - * A composable batch of [[Command]]s that itself behaves as a single [[Command]]. - * - * Queues enable staging, simplification, and reordering of edits before applying them to a [[MutableEGraph]]. - * When applied, commands run in sequence, threading the evolving reification map through each step. - * - * @tparam NodeT Node type for expressions represented by the e-graph. - * @param commands The commands in execution order (prior to [[optimized]]). - */ -final case class CommandQueue[NodeT](commands: Seq[Command[NodeT]]) extends Command[NodeT] { - - /** All symbols used by the commands in this queue. */ - override def uses: Seq[EClassSymbol] = commands.flatMap(_.uses) - - /** All virtual symbols defined by the commands in this queue. */ - override def definitions: Seq[EClassSymbol.Virtual] = commands.flatMap(_.definitions) - - /** - * Applies each command in order, threading the latest graph and reification. - * - * For each step, the current graph and the accumulated virtual-to-real bindings are passed to the next command. - * - * @param egraph Initial graph snapshot. - * @param reification Initial virtual-to-concrete bindings available to the first command. This map is - * mutated in place to include new bindings from each command. - * @param parallelize Strategy for distributing work across commands that can parallelize internally. - * @return `true` if at least one command changed the graph, otherwise `false`. - */ - override def apply(egraph: MutableEGraph[NodeT], - reification: mutable.Map[EClassSymbol.Virtual, EClassCall], - parallelize: ParallelMap): Boolean = { - var anyChanges: Boolean = false - for (command <- commands) { - val changed = command.apply(egraph, reification, parallelize) - anyChanges ||= changed - } - anyChanges - } - - /** - * Appends an insertion of a single [[ENodeSymbol]] as an [[AddManyCommand]] of size 1. - * - * @param node Node to add. - * @return The fresh output symbol for the node and the extended queue. - * - * @example - * {{{ - * val (sym, q2) = q1.add(ENodeSymbol(op, defs, uses, args)) - * }}} - */ - def add(node: ENodeSymbol[NodeT]): (EClassSymbol, CommandQueue[NodeT]) = { - val result = EClassSymbol.virtual() - (result, CommandQueue(commands :+ AddManyCommand(ArraySeq(result -> node)))) - } - - /** - * Appends an insertion of a [[MixedTree]]. - * - * Child subtrees are added first (depth-first), then a final node is inserted referencing - * their produced symbols. If the tree is a [[MixedTree.Atom]], it is treated as already-real - * and no command is added. - * - * @param tree Tree to insert. - * @return The symbol for the tree's root and the extended queue. - * @example - * {{{ - * val (root, q2) = q1.add(myTree) - * }}} - */ - def add(tree: MixedTree[NodeT, EClassSymbol]): (EClassSymbol, CommandQueue[NodeT]) = { - val builder = new CommandQueueBuilder[NodeT]() - builder.appendAll(commands) - val result = builder.add(tree) - (result, builder.result()) - } - - /** - * Appends a [[UnionManyCommand]] of size 1. - * - * @param left Left class symbol. - * @param right Right class symbol. - * @return The extended queue. - * - * @example - * {{{ - * val q2 = q1.union(a, b) - * }}} - */ - def union(left: EClassSymbol, right: EClassSymbol): CommandQueue[NodeT] = - CommandQueue(commands :+ UnionManyCommand(Seq((left, right)))) - - /** - * Appends a single [[Command]] to the end of this queue. - * - * @param command Command to add. - * @return The extended queue. - */ - def chain(command: Command[NodeT]): CommandQueue[NodeT] = - CommandQueue(commands :+ command) - - /** - * Concatenates another [[CommandQueue]] to the end of this queue. - * - * @param commandQueue Queue to append. - * @return The extended queue. - */ - def chain(commandQueue: CommandQueue[NodeT]): CommandQueue[NodeT] = - CommandQueue(commands ++ commandQueue.commands) - - /** - * Rewrites this queue into an equivalent but cheaper sequence by: - * - flattening nested queues, - * - merging adjacent [[UnionManyCommand]]s, - * - batching independent [[AddManyCommand]]s and layering dependent ones. - * - * Commands may be reordered where independence permits. - * - * @return An optimized queue; does not modify this instance. - * - * @example - * {{{ - * val qOptim = q.optimized - * }}} - */ - def optimized: CommandQueue[NodeT] = - CommandQueue(CommandQueue.optimizeCommands(flatCommands)) - - /** - * Recursively flattens nested [[CommandQueue]]s. - * - * @return A sequence of leaf commands in program order. - */ - private def flatCommands: Seq[Command[NodeT]] = - commands.flatMap { - case queue: CommandQueue[NodeT] => queue.flatCommands - case command => Seq(command) - } -} - -/** - * Helpers for constructing and optimizing [[CommandQueue]]s. - */ -object CommandQueue { - - private val emptyQueue: CommandQueue[_] = CommandQueue(Seq.empty) - - /** - * Creates an empty queue. - * - * @tparam NodeT Node type for expressions represented by the e-graph. - */ - def empty[NodeT]: CommandQueue[NodeT] = emptyQueue.asInstanceOf[CommandQueue[NodeT]] - - /** Applies union-merge and add-batching passes to a flat list of commands. */ - private def optimizeCommands[NodeT](commands: Seq[Command[NodeT]]): Seq[Command[NodeT]] = { - mergeUnions(commands, others => { - val adds = others.collect { case cmd: AddManyCommand[NodeT] => cmd } - if (adds.size == others.size) { - optimizeAdds(adds) - } else { - CommandQueue.independentGroups(others).flatMap(CommandQueue.optimizeIndependentGroup) - } - }) - } - - /** - * Merges all [[UnionManyCommand]]s in a group into a single command, then processes the rest. - * - * @param group Flat list of commands. - * @param processRemaining Pass to handle non-union commands once unions are merged. - * @return Optimized sequence with at most one [[UnionManyCommand]] in the tail. - */ - private def mergeUnions[NodeT]( - group: Seq[Command[NodeT]], - processRemaining: Seq[Command[NodeT]] => Seq[Command[NodeT]] - ): Seq[Command[NodeT]] = { - // Partition the remaining commands into union commands and other commands. - val (unionCommands, remainingCommands) = group.partition { - case _: UnionManyCommand[NodeT] => true - case _ => false - } match { - case (left, right) => (left.collect { case u: UnionManyCommand[NodeT] => u }, right) - } - - // Merge all the union commands. - val unionPairs = unionCommands.flatMap(_.pairs) - unionPairs match { - case Seq() => processRemaining(remainingCommands) - case Seq(_*) => processRemaining(remainingCommands) :+ UnionManyCommand[NodeT](unionPairs) - } - } - - - @inline private def highestDependencyForNodes[NodeT]( - nodes: ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])], - defs: util.IdentityHashMap[EClassSymbol, Int] - ): Int = { - var highest = -1 - var n = 0 - while (n < nodes.length) { - val args = nodes(n)._2.args - var a = 0 - while (a < args.length) { - args(a) match { - case use: EClassSymbol.Virtual => - if (defs.containsKey(use)) { - val idx = defs.get(use) - if (idx > highest) highest = idx - } - case _ => // ignore non-virtual uses - } - a += 1 - } - n += 1 - } - highest - } - - /** - * Batches independent [[AddManyCommand]]s into layers based on their intra-batch dependencies. - * - * Nodes whose arguments are produced in earlier layers are scheduled in later layers; unrelated - * nodes share a layer and can be added together. - * - * @param group Only add-commands (pre-filtered). - * @return A sequence of batched [[AddManyCommand]]s in dependency order. - */ - private def optimizeAdds[NodeT](group: Seq[AddManyCommand[NodeT]]): Seq[Command[NodeT]] = { - // Our aim is to partition the add commands into batches of independent additions. We do this by tracking the - // batches in which each node is defined. When we encounter a fresh addition, we add it to batch i + 1 such that - // i is the highest batch in which any of its dependencies are defined. - type ArraySeqBuilder = mutable.Builder[(EClassSymbol.Virtual, ENodeSymbol[NodeT]), ArraySeq[(EClassSymbol.Virtual, ENodeSymbol[NodeT])]] - val batches = mutable.ArrayBuffer.empty[ArraySeqBuilder] - - // Pre-size to avoid rehashing during hot insert loop. - var totalDefs = 0 - for (command <- group) { - totalDefs += command.nodes.length - } - - val defs = new util.IdentityHashMap[EClassSymbol, Int](totalDefs) - for (command <- group) { - val highestDependency = highestDependencyForNodes[NodeT](command.nodes, defs) - - val batchIndex = highestDependency + 1 - if (batchIndex == batches.size) { - val newBatch = ArraySeq.newBuilder[(EClassSymbol.Virtual, ENodeSymbol[NodeT])] - newBatch ++= command.nodes - batches += newBatch - } else { - batches(batchIndex) ++= command.nodes - } - - // Record the batch in which each node is defined. - var i = 0 - while (i < command.nodes.length) { - val node = command.nodes(i) - defs.put(node._1, batchIndex) - i += 1 - } - } - - // Merge all the addition commands in each batch. - batches.map(_.result()).map(AddManyCommand[NodeT]).toSeq - } - - /** - * Merges independent subgroups: batches adjacent [[AddManyCommand]]s and leaves others as-is. - * - * @param group A set of commands known to be independent of each other. - * @return Optimized sequence for that group. - */ - private def optimizeIndependentGroup[NodeT](group: Seq[Command[NodeT]]): Seq[Command[NodeT]] = { - // Partition the commands into addition commands and other commands. - val (addCommands, remainingCommands) = group.partition { - case _: AddManyCommand[NodeT] => true - case _ => false - } match { - case (left, right) => (left.collect { case a: AddManyCommand[NodeT] => a }, right) - } - - // Merge all the addition and union commands. - val addPairs = addCommands.flatMap(_.nodes) - addPairs match { - case Seq() => remainingCommands - case Seq(_*) => remainingCommands :+ AddManyCommand[NodeT](ArraySeq(addPairs: _*)) - } - } - - /** - * Groups commands into independent sets using virtual-symbol dataflow. - * - * A command A depends on B if A uses a virtual symbol that B defines. The result is - * a topologically ordered partition where commands inside the same group have no such - * dependencies and may run in any order. - * - * @return A sequence of independent command groups in execution order. - */ - private def independentGroups[NodeT](commands: Seq[Command[NodeT]]): Seq[Seq[Command[NodeT]]] = { - val commandNumbers = commands.indices - - val defs = commandNumbers.flatMap(i => { - commands(i).definitions.map(_ -> i) - }).toMap - - // Step 1: Create a dependency graph - val dependencyGraph = mutable.Map.empty[Int, Set[Int]] - val reverseDependencies = mutable.Map.empty[Int, Set[Int]] - - for (command <- commandNumbers) { - dependencyGraph(command) = Set.empty - reverseDependencies(command) = Set.empty - } - - for (i <- commandNumbers) { - for (use <- commands(i).uses.collect { case u: EClassSymbol.Virtual => u }) { - defs.get(use) match { - case Some(j) if j != i => - dependencyGraph(i) += j - reverseDependencies(j) += i - case _ => - } - } - } - - // Step 2: Topological sort to find independent sets of commands - val sortedCommands = mutable.ArrayBuffer.empty[Int] - val noIncomingEdges = mutable.Queue(commandNumbers.filter(c => dependencyGraph(c).isEmpty): _*) - - while (noIncomingEdges.nonEmpty) { - val command = noIncomingEdges.dequeue() - sortedCommands += command - - for (dependent <- reverseDependencies(command)) { - dependencyGraph(dependent) -= command - if (dependencyGraph(dependent).isEmpty) { - noIncomingEdges.enqueue(dependent) - } - } - } - - // Step 3: Merge independent commands - val groups = mutable.ArrayBuffer.empty[Seq[Command[NodeT]]] - var currentBatch = mutable.ArrayBuffer.empty[Command[NodeT]] - var currentBatchDefs = Set.empty[EClassSymbol] - - for (command <- sortedCommands) { - val cmd = commands(command) - if (currentBatch.nonEmpty && cmd.uses.toSet.intersect(currentBatchDefs).nonEmpty) { - groups += currentBatch.toSeq - currentBatch = mutable.ArrayBuffer.empty - currentBatchDefs = Set.empty - } - - currentBatch += cmd - currentBatchDefs ++= cmd.definitions - } - - if (currentBatch.nonEmpty) { - groups += currentBatch.toSeq - } - - groups.toSeq - } -} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueueBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandQueueBuilder.scala deleted file mode 100644 index e45b607c..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandQueueBuilder.scala +++ /dev/null @@ -1,242 +0,0 @@ -package foresight.eqsat.commands - -import foresight.eqsat.collections.SlotSeq -import foresight.eqsat.readonly.EGraph -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} -import foresight.util.collections.UnsafeSeqFromArray - -import scala.collection.compat.immutable.ArraySeq -import scala.collection.mutable - -/** - * Incrementally constructs a [[CommandQueue]] for later execution. - * - * This builder is a mutable convenience wrapper around [[CommandQueue]]’s - * immutable API. It allows code to append commands without repeatedly - * reassigning the queue, and to retrieve the finished queue via [[queue]]. - * - * Typical usage is to: - * 1. Create a `CommandQueueBuilder` - * 2. Append additions or unions in the desired order - * 3. Retrieve the resulting [[CommandQueue]] for simplification/optimization/application - * - * @tparam NodeT Node type for expressions represented by the e-graph. - * - * @example - * {{{ - * val b = new CommandQueueBuilder[MyNode] - * val a = b.add(myTree) // add a tree, get its symbol - * val bSym = b.add(ENodeSymbol(op, Nil, Nil, Seq(a))) - * b.union(a, bSym) // request that their classes be merged - * val q: CommandQueue[MyNode] = b.queue.optimized - * }}} - */ -final class CommandQueueBuilder[NodeT] { - private var commands: mutable.Builder[Command[NodeT], Seq[Command[NodeT]]] = null - - /** - * The [[CommandQueue]] accumulated so far. - */ - def result(): CommandQueue[NodeT] = { - if (commands == null) CommandQueue.empty - else CommandQueue(commands.result()) - } - - private def initCommands(): Unit = { - if (commands == null) commands = Seq.newBuilder[Command[NodeT]] - } - - /** - * Appends a [[Command]] to the queue. - * - * @param cmd Command to append. - */ - def append(cmd: Command[NodeT]): Unit = { - initCommands() - commands += cmd - } - - /** - * Appends multiple [[Command]]s to the queue. - * - * @param cmds Commands to append. - */ - def appendAll(cmds: Iterable[Command[NodeT]]): Unit = { - initCommands() - commands ++= cmds - } - - /** - * Appends an insertion of an [[ENodeSymbol]]. - * - * Internally wraps the node in a single-node [[AddManyCommand]]. - * - * @param node Node to insert. - * @return The fresh [[EClassSymbol.Virtual]] assigned to the inserted node’s e-class. - */ - def add(node: ENodeSymbol[NodeT]): EClassSymbol = { - val result = EClassSymbol.virtual() - append(AddManyCommand(ArraySeq(result -> node))) - result - } - - /** - * Appends an insertion of a [[MixedTree]]. - * - * Child subtrees are inserted first, then a final node referencing their - * symbols is added. If the tree is a [[MixedTree.Atom]], no new command - * is added and the existing [[EClassSymbol.Real]] is returned. - * - * @param tree Tree to insert. - * @return The [[EClassSymbol]] for the tree’s root e-class. - */ - def add(tree: MixedTree[NodeT, EClassSymbol]): EClassSymbol = { - tree match { - case MixedTree.Node(t, defs, uses, args) => - val result = EClassSymbol.virtual() - append(AddManyCommand(ArraySeq(result -> ENodeSymbol(t, defs, uses, args.map(add))))) - result - - case MixedTree.Atom(call) => - call - } - } - - /** - * Appends an insertion of a [[MixedTree]], using the provided e-graph - * to simplify the insertion. - * - * Child subtrees are inserted first, then a final node referencing their - * symbols is added. If the tree is a [[MixedTree.Atom]], no new command - * is added and the existing [[EClassSymbol.Real]] is returned. - * - * @param tree Tree to insert. - * @return The [[EClassSymbol]] for the tree’s root e-class. - */ - def addSimplified(tree: MixedTree[NodeT, EClassSymbol], egraph: EGraph[NodeT]): EClassSymbol = { - tree match { - case MixedTree.Node(t, defs, uses, args) => - val argSymbols = CommandQueueBuilder.symbolArrayFrom( - args, - (tree: MixedTree[NodeT, EClassSymbol]) => addSimplified(tree, egraph)) - addSimplifiedNode(t, defs, uses, argSymbols, egraph) - - case MixedTree.Atom(call) => call - } - } - - private[eqsat] def addSimplifiedNode(nodeType: NodeT, - definitions: SlotSeq, - uses: SlotSeq, - args: Array[EClassSymbol], - egraph: EGraph[NodeT]): EClassSymbol = { - - // Check if all children are already in the graph - val argCalls = CommandQueueBuilder.resolveAllOrNull(args) - - // If the children are already present, we might not need to add a new node - if (argCalls != null) { - val candidateNode = ENode.unsafeWrapArrays(nodeType, definitions, uses, argCalls) - egraph.findOrNull(candidateNode) match { - case null => - // Node does not exist in the graph; queue it for insertion - val result = EClassSymbol.virtual() - append(AddManyCommand(ArraySeq(result -> candidateNode))) - result - - case existingCall => - // Node already exists in the graph; reuse its class - EClassSymbol.real(existingCall) - } - } else { - val result = EClassSymbol.virtual() - val candidateNode = ENodeSymbol[NodeT](nodeType, definitions, uses, UnsafeSeqFromArray(args)) - append(AddManyCommand(ArraySeq(result -> candidateNode))) - result - } - } - - private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT]): EClassSymbol = { - tree match { - case MixedTree.Node(t, defs, uses, args) => - val argSymbols = CommandQueueBuilder.symbolArrayFrom( - args, (tree: MixedTree[NodeT, EClassCall]) => addSimplifiedReal(tree, egraph)) - addSimplifiedNode(t, defs, uses, argSymbols, egraph) - - case MixedTree.Atom(call) => EClassSymbol.real(call) - } - } - - /** - * Appends a [[UnionManyCommand]] request to merge two e-classes. - * - * @param a First class symbol. - * @param b Second class symbol. - */ - def union(a: EClassSymbol, b: EClassSymbol): Unit = { - append(UnionManyCommand(Seq((a, b)))) - } - - /** - * Appends a [[UnionManyCommand]] request to merge two e-classes, - * but only if they are not already known to be equivalent in the - * provided e-graph. - * - * If both `a` and `b` are [[EClassSymbol.Real]], their canonical - * representatives in the e-graph are compared; if they differ, a - * union command is added. If either is virtual, a union command is - * always added. - * - * @param a First class symbol. - * @param b Second class symbol. - * @param egraph E-graph used to check existing equivalences. - */ - def unionSimplified(a: EClassSymbol, b: EClassSymbol, egraph: EGraph[NodeT]): Unit = { - (a, b) match { - case (callA: EClassCall, callB: EClassCall) => - if (egraph.canonicalize(callA) != egraph.canonicalize(callB)) { - append(UnionManyCommand(Seq((a, b)))) - } - case _ => - append(UnionManyCommand(Seq((a, b)))) - } - } -} - -private[eqsat] object CommandQueueBuilder { - def symbolArrayFrom[A](values: ArraySeq[A], valueToSymbol: A => EClassSymbol): Array[EClassSymbol] = { - // Try to avoid allocating an array of EClassSymbol if all entries are EClassCall. - // The common case is that all children are already in the e-graph, and we will - // want to construct an ENode with an Array[EClassCall]. - // If we find any entry that is not an EClassCall, we fall back to allocating - // an Array[EClassSymbol] and copying the prefix of calls. - val n = values.length - val calls = new Array[EClassCall](n) - var i = 0 - while (i < n) { - valueToSymbol(values(i)) match { - case c: EClassCall => - calls(i) = c - case other => - // Fallback: allocate symbols array, copy the prefix of calls, and finish filling - val syms = new Array[EClassSymbol](n) - var j = 0 - while (j < i) { syms(j) = calls(j); j += 1 } - syms(i) = other - j = i + 1 - while (j < n) { syms(j) = valueToSymbol(values(j)); j += 1 } - return syms - } - i += 1 - } - // All entries were EClassCall. Perform a safe upcast to Array[EClassSymbol] - calls.asInstanceOf[Array[EClassSymbol]] - } - - def resolveAllOrNull(args: Array[EClassSymbol]): Array[EClassCall] = { - if (args.isInstanceOf[Array[EClassCall]]) - args.asInstanceOf[Array[EClassCall]] - else - null - } -} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala new file mode 100644 index 00000000..c825cd54 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala @@ -0,0 +1,137 @@ +package foresight.eqsat.commands + +import foresight.eqsat.parallel.ParallelMap +import foresight.eqsat.{AddNodeResult, EClassCall, EClassSymbol, ENode, ENodeSymbol, mutable, immutable} + +import java.util +import scala.collection.compat.immutable.ArraySeq + +/** + * A schedule of commands to be executed on an e-graph. + * + * Additions are grouped into batches, where all additions in lower-numbered + * batches are executed before additions in higher-numbered batches. There is + * no ordering guarantee between additions within the same batch. + * + * Unions are executed after all additions have been processed. + * + * @tparam NodeT Node type for expressions represented by the e-graph. + */ +final case class CommandSchedule[NodeT](batchZero: (ArraySeq[EClassSymbol.Virtual], ArraySeq[ENode[NodeT]]), + otherBatches: Seq[(ArraySeq[EClassSymbol.Virtual], ArraySeq[ENodeSymbol[NodeT]])], + unions: ArraySeq[(EClassSymbol, EClassSymbol)] + ) { + /** + * The additions scheduled in this command schedule, grouped by batch. + */ + def additions: Seq[(ArraySeq[EClassSymbol.Virtual], ArraySeq[ENodeSymbol[NodeT]])] = { + batchZero +: otherBatches + } + + private type ReificationMap = util.IdentityHashMap[EClassSymbol.Virtual, EClassCall] + + private def processAdditionResults(symbols: ArraySeq[EClassSymbol.Virtual], + addResults: ArraySeq[AddNodeResult], + reification: ReificationMap): Boolean = { + var anyChanges: Boolean = false + + var i = 0 + while (i < symbols.length) { + val symbol = symbols(i) + val result = addResults(i) + reification.put(symbol, result.call) + + result match { + case AddNodeResult.Added(_) => anyChanges = true + case AddNodeResult.AlreadyThere(_) => // no change + } + i += 1 + } + + anyChanges + } + + private def applyBatchZero(egraph: mutable.EGraph[NodeT], + parallelize: ParallelMap, + reification: ReificationMap): Boolean = { + processAdditionResults(batchZero._1, egraph.tryAddMany(batchZero._2, parallelize), reification) + } + + private def applyReifiedBatch(egraph: mutable.EGraph[NodeT], + batch: (ArraySeq[EClassSymbol.Virtual], ArraySeq[ENodeSymbol[NodeT]]), + parallelize: ParallelMap, + reification: ReificationMap): Boolean = { + val reifiedNodes = batch._2.map(_.reify(reification.get)) + processAdditionResults(batch._1, egraph.tryAddMany(reifiedNodes, parallelize), reification) + } + + private def applyUnions(egraph: mutable.EGraph[NodeT], + unions: ArraySeq[(EClassSymbol, EClassSymbol)], + parallelize: ParallelMap, + reification: ReificationMap): Boolean = { + val reifiedUnions = unions + .map { case (l, r) => (l.reify(reification.get), r.reify(reification.get)) } + .filter { case (l, r) => !egraph.areSame(l, r) } + + if (reifiedUnions.isEmpty) { + false + } else { + egraph.unionMany(reifiedUnions, parallelize) + true + } + } + + /** + * Executes the scheduled commands against an e-graph. + * + * @param egraph + * Destination e-graph that will be mutated in place. + * @param parallelize + * Parallelization strategy to label and distribute any internal work. + * @return `true` if any change occurred, or `false` for a no-op. + */ + def apply(egraph: mutable.EGraph[NodeT], + parallelize: ParallelMap): Boolean = { + + val reification = util.IdentityHashMap[EClassSymbol.Virtual, foresight.eqsat.EClassCall]() + + var anyChanges: Boolean = false + anyChanges = anyChanges | applyBatchZero(egraph, parallelize, reification) + for (batch <- otherBatches) { + anyChanges = anyChanges | applyReifiedBatch(egraph, batch, parallelize, reification) + } + + anyChanges = anyChanges | applyUnions(egraph, unions, parallelize, reification) + + anyChanges + } + + /** + * Executes the command schedule against an immutable e-graph. + * + * @param egraph + * Destination e-graph. Implementations may either return it unchanged or produce a new immutable + * e-graph snapshot. + * @param reification + * Mapping from virtual symbols to concrete calls available before this command runs. This is used + * to ground virtual references present in [[uses]]. + * @param parallelize + * Parallelization strategy to label and distribute any internal work. + * @return + * `Some(newGraph)` if any change occurred, or `None` for a no-op. + */ + def applyImmutable[ + Repr <: immutable.EGraphLike[NodeT, Repr] with immutable.EGraph[NodeT] + ]( + egraph: Repr, + parallelize: ParallelMap + ): Option[Repr] = { + val mutableEGraph = mutable.FreezableEGraph[NodeT, Repr](egraph) + val updated = apply(mutableEGraph, parallelize) + if (updated) { + Some(mutableEGraph.freeze()) + } else { + None + } + } +} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala new file mode 100644 index 00000000..d62ab9c0 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -0,0 +1,200 @@ +package foresight.eqsat.commands + +import foresight.eqsat.collections.SlotSeq +import foresight.eqsat.readonly.EGraph +import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} +import foresight.util.Debug +import foresight.util.collections.UnsafeSeqFromArray + +import scala.collection.compat.immutable.ArraySeq +import scala.runtime.IntRef + +/** + * Constructs commands for later execution. Commands are scheduled in batches. + * Union commands are executed at the end of the schedule while add commands + * are grouped into batches. Additions with lower batch numbers + * are always executed before additions with higher batch numbers. There is no ordering + * guarantee between additions with the same batch number. + */ +trait CommandScheduleBuilder[NodeT] { + /** + * Appends an add command to the schedule. This command adds the given node + * to the e-graph, producing an e-class with the given symbolic name. + * @param symbol The e-class symbol of the e-class that contains or will contain the node. + * @param node The node to add. + * @param batch The batch number for scheduling. + */ + def add(symbol: EClassSymbol.Virtual, node: ENodeSymbol[NodeT], batch: Int): Unit + + /** + * Appends a union command to the schedule. This command unions the two given e-classes + * in the e-graph. + * @param a The first e-class symbol to union. + * @param b The second e-class symbol to union. + */ + def union(a: EClassSymbol, b: EClassSymbol): Unit + + /** + * Synthesizes the accumulated additions and unions into a single command. + * + * Calling this method finalizes the schedule. No further additions or unions + * may be appended after calling this method. + * + * @return The additions and unions as a single command. + */ + def result(): CommandSchedule[NodeT] + + /** + * Appends an add command to the schedule. This command adds the given node + * to the e-graph, producing a new e-class with a fresh symbolic name. + * + * @param node The node to add. + * @param batch The batch number for scheduling. + * @return The fresh e-class symbol assigned to the added node's e-class. + */ + final def add(node: ENodeSymbol[NodeT], batch: Int): EClassSymbol.Virtual = { + val symbol = EClassSymbol.virtual() + add(symbol, node, batch) + symbol + } + + /** + * Appends a request to merge two e-classes, but only if they are + * not already known to be equivalent in the provided e-graph. + * + * If both `a` and `b` are [[EClassSymbol.Real]], their canonical + * representatives in the e-graph are compared; if they differ, a + * union command is added. If either is virtual, a union command is + * always added. + * + * @param a First class symbol. + * @param b Second class symbol. + * @param egraph E-graph used to check existing equivalences. + */ + final def unionSimplified(a: EClassSymbol, b: EClassSymbol, egraph: EGraph[NodeT]): Unit = { + (a, b) match { + case (callA: EClassCall, callB: EClassCall) => + if (egraph.canonicalize(callA) != egraph.canonicalize(callB)) { + union(a, b) + } + case _ => + union(a, b) + } + } + + private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + egraph: EGraph[NodeT]): EClassSymbol = { + val maxBatch = IntRef(0) + addSimplifiedReal(tree, egraph, maxBatch) + } + + private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + egraph: EGraph[NodeT], + maxBatch: IntRef): EClassSymbol = { + tree match { + case MixedTree.Node(t, defs, uses, args) => + // Local accumulator for children of this node. + val childMax = IntRef(0) + val argSymbols = CommandScheduleBuilder.symbolArrayFrom( + args, + childMax, + (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb) + ) + val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph) + // Propagate maximum required batch up to the caller's accumulator. + if (childMax.elem > maxBatch.elem) maxBatch.elem = childMax.elem + sym + + case MixedTree.Atom(call) => + // No insertion required; keep caller's accumulator unchanged. + EClassSymbol.real(call) + } + } + + private[eqsat] def addSimplifiedNode(nodeType: NodeT, + definitions: SlotSeq, + uses: SlotSeq, + args: Array[EClassSymbol], + maxBatch: IntRef, + egraph: EGraph[NodeT]): EClassSymbol = { + + // Check if all children are already in the graph. + val argCalls = CommandScheduleBuilder.resolveAllOrNull(args) + + // If the children are already present, we might not need to add a new node. + if (argCalls != null) { + if (Debug.isEnabled) { + assert(maxBatch.elem == 0) + } + + val candidateNode = ENode.unsafeWrapArrays(nodeType, definitions, uses, argCalls) + egraph.findOrNull(candidateNode) match { + case null => + // Node does not exist in the graph but its children do exist in the graph. + // Queue it for insertion in batch zero. + add(candidateNode, 0) + + case existingCall => + // Node already exists in the graph; reuse its class. + EClassSymbol.real(existingCall) + } + } else { + val candidateNode = ENodeSymbol[NodeT](nodeType, definitions, uses, UnsafeSeqFromArray(args)) + maxBatch.elem += 1 + add(candidateNode, maxBatch.elem) + } + } +} + +/** + * Companion object for [[CommandScheduleBuilder]]. + */ +object CommandScheduleBuilder { + /** + * Creates a new concurrent command schedule builder. This builder can safely be used + * from multiple threads. + * @tparam NodeT The type of the nodes in the e-graph. + * @return A new concurrent command schedule builder. + */ + def newConcurrentBuilder[NodeT]: CommandScheduleBuilder[NodeT] = new ConcurrentCommandScheduleBuilder[NodeT]() + + private[eqsat] def symbolArrayFrom[A](values: ArraySeq[A], maxBatch: IntRef, valueToSymbol: (A, IntRef) => EClassSymbol): Array[EClassSymbol] = { + // Try to avoid allocating an array of EClassSymbol if all entries are EClassCall. + // The common case is that all children are already in the e-graph, and we will + // want to construct an ENode with an Array[EClassCall]. + // If we find any entry that is not an EClassCall, we fall back to allocating + // an Array[EClassSymbol] and copying the prefix of calls. + val n = values.length + val calls = new Array[EClassCall](n) + var i = 0 + while (i < n) { + valueToSymbol(values(i), maxBatch) match { + case c: EClassCall => + calls(i) = c + case other => + // Fallback: allocate symbols array, copy the prefix of calls, and finish filling + val syms = new Array[EClassSymbol](n) + var j = 0 + while (j < i) { + syms(j) = calls(j); j += 1 + } + syms(i) = other + j = i + 1 + while (j < n) { + syms(j) = valueToSymbol(values(j), maxBatch); j += 1 + } + return syms + } + i += 1 + } + // All entries were EClassCall. Perform a safe upcast to Array[EClassSymbol] + calls.asInstanceOf[Array[EClassSymbol]] + } + + private[eqsat] def resolveAllOrNull(args: Array[EClassSymbol]): Array[EClassCall] = { + if (args.isInstanceOf[Array[EClassCall]]) + args.asInstanceOf[Array[EClassCall]] + else + null + } +} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala new file mode 100644 index 00000000..e2904f4c --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala @@ -0,0 +1,77 @@ +package foresight.eqsat.commands + +import foresight.eqsat.{EClassSymbol, ENode, ENodeSymbol} +import foresight.util.collections.UnsafeSeqFromArray + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import scala.collection.compat.immutable.ArraySeq +import scala.reflect.ClassTag + +private[commands] class ConcurrentCommandScheduleBuilder[NodeT] extends CommandScheduleBuilder[NodeT] { + private val batchZeroAdds = new ConcurrentLinkedQueue[(EClassSymbol.Virtual, ENode[NodeT])]() + private val otherBatchAdds = new ConcurrentHashMap[Int, ConcurrentLinkedQueue[(EClassSymbol.Virtual, ENodeSymbol[NodeT])]]() + private val unions: ConcurrentLinkedQueue[(EClassSymbol, EClassSymbol)] = new ConcurrentLinkedQueue() + + override def add(symbol: EClassSymbol.Virtual, node: ENodeSymbol[NodeT], batch: Int): Unit = { + if (batch == 0) { + node match { + case eNode: ENode[NodeT] => batchZeroAdds.add((symbol, eNode)) + case _ => throw new IllegalArgumentException("Only ENode instances are allowed in batch 0") + } + } else { + val queue = otherBatchAdds.computeIfAbsent(batch, _ => new ConcurrentLinkedQueue()) + queue.add((symbol, node)) + } + } + + override def union(a: EClassSymbol, b: EClassSymbol): Unit = { + unions.add((a, b)) + } + + private def _highestBatchIndex: Int = { + var highestBatchIndex = 0 + val batchIndexIterator = otherBatchAdds.keySet.iterator() + while (batchIndexIterator.hasNext) { + val batchIndex = batchIndexIterator.next() + if (batchIndex > highestBatchIndex) { + highestBatchIndex = batchIndex + } + } + highestBatchIndex + } + + private def linkedQueueToSplitArrays[A: ClassTag](queue: ConcurrentLinkedQueue[(EClassSymbol.Virtual, A)]): (ArraySeq[EClassSymbol.Virtual], ArraySeq[A]) = { + val size = queue.size() + val symbols = new Array[EClassSymbol.Virtual](size) + val nodes = new Array[A](size) + + val iterator = queue.iterator() + var index = 0 + while (iterator.hasNext) { + val (first, second) = iterator.next() + symbols(index) = first + nodes(index) = second + index += 1 + } + + (UnsafeSeqFromArray(symbols), UnsafeSeqFromArray(nodes)) + } + + override def result(): CommandSchedule[NodeT] = { + val batchZero = linkedQueueToSplitArrays[ENode[NodeT]](batchZeroAdds) + + val highestBatchIndex = _highestBatchIndex + val batches = new Array[(ArraySeq[EClassSymbol.Virtual], ArraySeq[ENodeSymbol[NodeT]])](highestBatchIndex) + for (batchIndex <- 1 to highestBatchIndex) { + val queue = otherBatchAdds.get(batchIndex) + if (queue != null) { + batches(batchIndex - 1) = linkedQueueToSplitArrays[ENodeSymbol[NodeT]](queue) + } else { + batches(batchIndex - 1) = (ArraySeq.empty, ArraySeq.empty) + } + } + + val unionArraySeq = UnsafeSeqFromArray(unions.toArray(new Array[(EClassSymbol, EClassSymbol)](0))) + CommandSchedule(batchZero, batches.toSeq, unionArraySeq) + } +} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala b/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala deleted file mode 100644 index 58bc42ab..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/UnionManyCommand.scala +++ /dev/null @@ -1,70 +0,0 @@ -package foresight.eqsat.commands - -import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol} -import foresight.eqsat.mutable -import foresight.eqsat.readonly - -import scala.collection.mutable.{Map => MutableMap} - -/** - * A [[Command]] that unions multiple pairs of e-classes in a single batch. - * - * Each entry in [[pairs]] may contain real or virtual [[EClassSymbol]]s. Virtual symbols are - * resolved via the reification map passed to [[apply]]. All unions are accumulated using - * [[EGraphWithPendingUnions]] and committed in one rebuild step if needed. - * - * This command defines no new virtual symbols; it only merges existing classes. - * - * @tparam NodeT Node type for expressions represented by the e-graph. - * @param pairs Pairs of symbols whose underlying classes should be unified. - */ -final case class UnionManyCommand[NodeT](pairs: Seq[(EClassSymbol, EClassSymbol)]) extends Command[NodeT] { - - /** All symbols referenced by this batch of unions. */ - override def uses: Seq[EClassSymbol] = - pairs.flatMap { case (l, r) => Seq(l, r) } - - /** No new e-classes are created by a union; this command defines nothing. */ - override def definitions: Seq[EClassSymbol.Virtual] = Seq.empty - - /** - * Resolves all symbols using the given map and enqueues each union into - * an [[EGraphWithPendingUnions]]. If any union changes the structure, - * a rebuild is performed. - * - * @param egraph Target e-graph to update. - * @param reification Mapping from virtual symbols to concrete calls, - * used to resolve the left/right sides before unioning. - * @param parallelize Strategy for distributing the rebuild, if needed. - * @return - * - `true` if at least one union required changes (triggering a rebuild), - * otherwise `false`. - * - An empty reification map (unions do not define outputs). - * - * @example - * {{{ - * val a: EClassSymbol = EClassSymbol.real(callA) - * val b: EClassSymbol = EClassSymbol.real(callB) - * val cmd = UnionManyCommand(Seq(a -> b)) - * val updated = cmd.apply(egraph, HashMap.empty, parallel) - * }}} - */ - override def apply( - egraph: mutable.EGraph[NodeT], - reification: MutableMap[EClassSymbol.Virtual, EClassCall], - parallelize: ParallelMap - ): Boolean = { - val reifiedPairs = pairs - .map { case (l, r) => (l.reify(reification), r.reify(reification)) } - .filter { case (l, r) => !egraph.areSame(l, r) } - - if (reifiedPairs.isEmpty) { - false - } - else { - egraph.unionMany(reifiedPairs, parallelize) - true - } - } -} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/package.scala b/foresight/src/main/scala/foresight/eqsat/commands/package.scala deleted file mode 100644 index 79d9637f..00000000 --- a/foresight/src/main/scala/foresight/eqsat/commands/package.scala +++ /dev/null @@ -1,83 +0,0 @@ -package foresight.eqsat - -/** - * Defines the command system for making batched, replayable edits to a [[mutable.EGraph]]. - * - * A [[commands.Command]] is an immutable description of a single edit, - * such as inserting one or more nodes ([[commands.AddManyCommand]]) or - * merging e-classes ([[commands.UnionManyCommand]]). - * - * Commands are: - * - **pure values** that do not mutate an e-graph until [[commands.Command.apply]] - * - **self-describing**: they declare their required inputs ([[commands.Command.uses]]) - * and any new virtual symbols they define ([[commands.Command.definitions]]) - * - **optimizable**: they can simplify themselves for a given e-graph - * ([[commands.Command.simplify]]) - * - * ## Symbols and Reification - * - * Commands refer to e-classes symbolically via [[EClassSymbol]]: - * - [[EClassCall]] for existing classes - * - [[EClassSymbol.Virtual]] for classes not yet in the graph - * - * When a command is applied, virtual symbols are mapped to real classes using - * a **reification map** (`Map[Virtual, EClassCall]`). This allows command sequences - * to be planned without knowing concrete IDs upfront. - * - * ## Batching and Composition - * - * Most real edits involve more than one command. The package provides: - * - * - [[commands.CommandQueue]] — an immutable batch of commands that itself - * implements [[commands.Command]], so queues can be nested or chained. - * - * - [[commands.CommandQueueBuilder]] — a mutable helper for incrementally building - * a queue without reassigning after each append. - * - * Queues can be: - * - **simplified**: removing redundant unions, skipping inserts for already-present nodes - * - **optimized**: merging independent commands for fewer graph rebuilds - * - * ## Common Command Types - * - * - [[commands.AddManyCommand]] — inserts one or more [[ENodeSymbol]]s, - * each paired with the [[EClassSymbol.Virtual]] that will represent - * its resulting class. - * - * - [[commands.UnionManyCommand]] — unifies multiple pairs of e-classes at once. - * - * - [[commands.CommandQueue]] — sequences arbitrary commands and can itself - * be nested inside other queues. - * - * ## Example - * - * {{{ - * import foresight.eqsat.commands._ - * - * val builder = new CommandQueueBuilder[MyNode] - * - * // Add a tree and get its root symbol - * val root = builder.add(myTree) - * - * // Add a dependent node - * val depNode = ENodeSymbol(op, Nil, Nil, Seq(root)) - * val depSym = builder.add(depNode) - * - * // Request a union of the two classes - * builder.union(root, depSym) - * - * // Get the queue, optimize it, and apply it - * val queue = builder.queue.optimized - * val (maybeNewGraph, reif) = queue.apply(egraph, Map.empty, parallel) - * }}} - * - * ## Lifecycle - * - * 1. **Build** — construct commands via `AddManyCommand`, `UnionManyCommand`, - * or convenience methods in [[commands.CommandQueue]] / [[commands.CommandQueueBuilder]]. - * 2. **Simplify** — run `simplify(egraph, partialMap)` to drop redundant work and - * pre-bind resolvable symbols. - * 3. **Optimize** — call `optimized` on queues to merge unions and batch adds. - * 4. **Apply** — run `apply(egraph, reification, parallelMap)` to produce an updated graph. - */ -package object commands diff --git a/foresight/src/main/scala/foresight/eqsat/hashCons/AbstractMutableHashConsEGraph.scala b/foresight/src/main/scala/foresight/eqsat/hashCons/AbstractMutableHashConsEGraph.scala index ab6ae19e..453c8b5c 100644 --- a/foresight/src/main/scala/foresight/eqsat/hashCons/AbstractMutableHashConsEGraph.scala +++ b/foresight/src/main/scala/foresight/eqsat/hashCons/AbstractMutableHashConsEGraph.scala @@ -5,7 +5,9 @@ import foresight.eqsat.collections.{SlotMap, SlotSet} import foresight.eqsat.mutable.EGraph import foresight.eqsat.parallel.ParallelMap import foresight.util.Debug +import foresight.util.collections.UnsafeSeqFromArray +import scala.collection.compat.immutable.ArraySeq import scala.collection.mutable.{HashMap, HashSet, LinkedHashSet} /** @@ -78,8 +80,8 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT] final override def canonicalizeOrNull(ref: EClassRef): EClassCall = unionFind.findAndCompressOrNull(ref) - final override def tryAddMany(nodes: Seq[ENode[NodeT]], - parallelize: ParallelMap): Seq[AddNodeResult] = { + final override def tryAddMany(nodes: ArraySeq[ENode[NodeT]], + parallelize: ParallelMap): ArraySeq[AddNodeResult] = { // Adding independent e-nodes is fundamentally a sequential operation, but the most expensive part of adding nodes // is canonicalizing them and looking them up in the e-graph. Canonicalization can be parallelized since adding a // node will never change the canonical form of other nodes - only union operations can do that. @@ -90,13 +92,14 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT] val p = parallelize.child("add nodes") + // FIXME: produce an array from the parallel map directly to avoid an extra copy. val canonicalized = p(nodes, canonicalize) val results = p.run { canonicalized.map { node => tryAddUnsafe(node) } } - results.toSeq + UnsafeSeqFromArray(results.toArray) } final override def unionMany(pairs: Seq[(EClassCall, EClassCall)], parallelize: ParallelMap): Set[Set[EClassCall]] = { diff --git a/foresight/src/main/scala/foresight/eqsat/hashCons/immutable/HashConsEGraph.scala b/foresight/src/main/scala/foresight/eqsat/hashCons/immutable/HashConsEGraph.scala index 99419d6b..a0a0b625 100644 --- a/foresight/src/main/scala/foresight/eqsat/hashCons/immutable/HashConsEGraph.scala +++ b/foresight/src/main/scala/foresight/eqsat/hashCons/immutable/HashConsEGraph.scala @@ -4,6 +4,9 @@ import foresight.eqsat.hashCons.{EClassData, ReadOnlyHashConsEGraph} import foresight.eqsat.immutable.{EGraph, EGraphLike} import foresight.eqsat.{AddNodeResult, EClassCall, EClassRef, ENode} import foresight.eqsat.parallel.ParallelMap +import foresight.util.collections.UnsafeSeqFromArray + +import scala.collection.compat.immutable.ArraySeq /** * An e-graph that uses hash-consing to map e-nodes to e-classes. @@ -51,8 +54,8 @@ private[eqsat] final case class HashConsEGraph[NodeT] private[hashCons](protecte classData(ref) } - override def tryAddMany(nodes: Seq[ENode[NodeT]], - parallelize: ParallelMap): (Seq[AddNodeResult], HashConsEGraph[NodeT]) = { + override def tryAddMany(nodes: ArraySeq[ENode[NodeT]], + parallelize: ParallelMap): (ArraySeq[AddNodeResult], HashConsEGraph[NodeT]) = { // Adding independent e-nodes is fundamentally a sequential operation, but the most expensive part of adding nodes // is canonicalizing them and looking them up in the e-graph. Canonicalization can be parallelized since adding a // node will never change the canonical form of other nodes - only union operations can do that. @@ -64,13 +67,15 @@ private[eqsat] final case class HashConsEGraph[NodeT] private[hashCons](protecte val p = parallelize.child("add nodes") val mutable = toBuilder + + // FIXME: produce an array from the parallel map directly to avoid an extra copy. val canonicalized = p(nodes, canonicalize) val results = p.run { canonicalized.map { node => mutable.tryAddUnsafe(node) } } - (results.toSeq, mutable.result()) + (UnsafeSeqFromArray(results.toArray), mutable.result()) } override def unionMany(pairs: Seq[(EClassCall, EClassCall)], diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala index d01fb1d1..74e597b3 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala @@ -4,6 +4,8 @@ import foresight.eqsat.{AddNodeResult, EClassCall, ENode, MixedTree, Tree} import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.readonly +import scala.collection.compat.immutable.ArraySeq + /** * An e-graph is a data structure for representing and maintaining equivalence classes of expressions. * E-graphs support equality saturation, a powerful technique for exploring all equivalent forms of a term @@ -79,7 +81,7 @@ trait EGraphLike[NodeT, +This <: EGraphLike[NodeT, This] with EGraph[NodeT]] ext * @param parallelize Strategy used for any parallel work within the addition. * @return (Per-node results in input order, new e-graph containing the additions). */ - def tryAddMany(nodes: Seq[ENode[NodeT]], parallelize: ParallelMap): (Seq[AddNodeResult], This) + def tryAddMany(nodes: ArraySeq[ENode[NodeT]], parallelize: ParallelMap): (ArraySeq[AddNodeResult], This) /** * Unions (merges) pairs of e-classes. @@ -116,7 +118,7 @@ trait EGraphLike[NodeT, +This <: EGraphLike[NodeT, This] with EGraph[NodeT]] ext * @return (E-class of `node`, new e-graph). */ final def add(node: ENode[NodeT]): (EClassCall, This) = { - tryAddMany(Seq(node), ParallelMap.sequential) match { + tryAddMany(ArraySeq(node), ParallelMap.sequential) match { case (Seq(AddNodeResult.Added(call)), egraph) => (call, egraph) case (Seq(AddNodeResult.AlreadyThere(call)), egraph) => (call, egraph) case _ => throw new IllegalStateException("Unexpected result from tryAddMany") diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala index fc542165..0e71e614 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala @@ -1,11 +1,13 @@ package foresight.eqsat.immutable -import foresight.eqsat._ +import foresight.eqsat.* import foresight.eqsat.readonly import foresight.eqsat.metadata.Analysis import foresight.eqsat.parallel.ParallelMap import foresight.util.collections.StrictMapOps.toStrictMapOps +import scala.collection.compat.immutable.ArraySeq + /** * Wrapper that couples an [[EGraph]] with a set of registered [[Metadata]] managers * and keeps them *in sync* on every change to the underlying e-graph. @@ -131,8 +133,8 @@ extends readonly.EGraphWithMetadata[NodeT, Repr] * - The per-node results from the underlying e-graph. * - A new wrapper with the updated e-graph and metadata. */ - override def tryAddMany(nodes: Seq[ENode[NodeT]], - parallelize: ParallelMap): (Seq[AddNodeResult], EGraphWithMetadata[NodeT, Repr]) = { + override def tryAddMany(nodes: ArraySeq[ENode[NodeT]], + parallelize: ParallelMap): (ArraySeq[AddNodeResult], EGraphWithMetadata[NodeT, Repr]) = { val (results, newEgraph) = egraph.tryAddMany(nodes, parallelize) val newNodes = nodes.zip(results).collect { case (node, AddNodeResult.Added(call)) => diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala index 1fdeb294..4d749896 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala @@ -1,9 +1,11 @@ package foresight.eqsat.immutable -import foresight.eqsat._ +import foresight.eqsat.* import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.rewriting.PortableMatch +import scala.collection.compat.immutable.ArraySeq + /** * An e-graph that records the set of matches that have been applied to it. * @@ -34,8 +36,8 @@ final case class EGraphWithRecordedApplications[ EGraphWithRecordedApplications(newEgraph, applied) } - override def tryAddMany(nodes: Seq[ENode[Node]], - parallelize: ParallelMap): (Seq[AddNodeResult], EGraphWithRecordedApplications[Node, Repr, Match]) = { + override def tryAddMany(nodes: ArraySeq[ENode[Node]], + parallelize: ParallelMap): (ArraySeq[AddNodeResult], EGraphWithRecordedApplications[Node, Repr, Match]) = { val (results, newEgraph) = egraph.tryAddMany(nodes, parallelize) // Construct a new EGraphWithRecordedApplications with the new e-graph and the same applied matches. The applied diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala index 8ed402a1..87d9dd07 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala @@ -1,8 +1,10 @@ package foresight.eqsat.immutable -import foresight.eqsat._ +import foresight.eqsat.* import foresight.eqsat.parallel.ParallelMap +import scala.collection.compat.immutable.ArraySeq + /** * An e-graph that has a root e-class. * @param egraph The underlying e-graph that contains the nodes and classes. @@ -42,8 +44,8 @@ final case class EGraphWithRoot[ EGraphWithRoot(newGraph, root) } - override def tryAddMany(nodes: Seq[ENode[Node]], - parallelize: ParallelMap): (Seq[AddNodeResult], EGraphWithRoot[Node, Repr]) = { + override def tryAddMany(nodes: ArraySeq[ENode[Node]], + parallelize: ParallelMap): (ArraySeq[AddNodeResult], EGraphWithRoot[Node, Repr]) = { egraph.tryAddMany(nodes, parallelize) match { case (results, newGraph) => (results, EGraphWithRoot(newGraph, root)) } diff --git a/foresight/src/main/scala/foresight/eqsat/mutable/EGraph.scala b/foresight/src/main/scala/foresight/eqsat/mutable/EGraph.scala index b2fd56d7..ab2fd58d 100644 --- a/foresight/src/main/scala/foresight/eqsat/mutable/EGraph.scala +++ b/foresight/src/main/scala/foresight/eqsat/mutable/EGraph.scala @@ -5,6 +5,8 @@ import foresight.eqsat.readonly import foresight.eqsat.{AddNodeResult, EClassCall, ENode, MixedTree, Tree} import foresight.eqsat.hashCons.mutable.HashConsEGraph +import scala.collection.compat.immutable.ArraySeq + /** * A mutable e-graph that supports adding e-nodes and merging e-classes. * @@ -26,7 +28,7 @@ trait EGraph[NodeT] extends readonly.EGraph[NodeT] { * @param parallelize Strategy used for any parallel work within the addition. * @return Per-node results in input order. */ - def tryAddMany(nodes: Seq[ENode[NodeT]], parallelize: ParallelMap): Seq[AddNodeResult] + def tryAddMany(nodes: ArraySeq[ENode[NodeT]], parallelize: ParallelMap): ArraySeq[AddNodeResult] /** * Unions (merges) pairs of e-classes. @@ -63,7 +65,7 @@ trait EGraph[NodeT] extends readonly.EGraph[NodeT] { * @return E-class of `node`. */ final def add(node: ENode[NodeT]): EClassCall = { - tryAddMany(Seq(node), ParallelMap.sequential) match { + tryAddMany(ArraySeq(node), ParallelMap.sequential) match { case Seq(AddNodeResult.Added(call)) => call case Seq(AddNodeResult.AlreadyThere(call)) => call case _ => throw new IllegalStateException("Unexpected result from tryAddMany") diff --git a/foresight/src/main/scala/foresight/eqsat/mutable/EGraphWithMetadata.scala b/foresight/src/main/scala/foresight/eqsat/mutable/EGraphWithMetadata.scala index a7ebc8a1..45b80ace 100644 --- a/foresight/src/main/scala/foresight/eqsat/mutable/EGraphWithMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/mutable/EGraphWithMetadata.scala @@ -4,6 +4,7 @@ import foresight.eqsat.metadata.Analysis import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.{AddNodeResult, EClassCall, ENode, readonly} +import scala.collection.compat.immutable.ArraySeq import scala.collection.mutable /** @@ -78,7 +79,7 @@ final class EGraphWithMetadata[ metadata.remove(name).isDefined } - override def tryAddMany(nodes: Seq[ENode[NodeT]], parallelize: ParallelMap): Seq[AddNodeResult] = { + override def tryAddMany(nodes: ArraySeq[ENode[NodeT]], parallelize: ParallelMap): ArraySeq[AddNodeResult] = { val results = egraph.tryAddMany(nodes, parallelize) val newNodes = nodes.zip(results).collect { case (node, AddNodeResult.Added(call)) => diff --git a/foresight/src/main/scala/foresight/eqsat/mutable/UpdatingImmutableEGraph.scala b/foresight/src/main/scala/foresight/eqsat/mutable/UpdatingImmutableEGraph.scala index e997ccee..ce9f031e 100644 --- a/foresight/src/main/scala/foresight/eqsat/mutable/UpdatingImmutableEGraph.scala +++ b/foresight/src/main/scala/foresight/eqsat/mutable/UpdatingImmutableEGraph.scala @@ -3,6 +3,8 @@ package foresight.eqsat.mutable import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.{AddNodeResult, EClassCall, EClassRef, ENode, ShapeCall, immutable} +import scala.collection.compat.immutable.ArraySeq + private final class UpdatingImmutableEGraph[ NodeT, EGraphT <: immutable.EGraph[NodeT] with immutable.EGraphLike[NodeT, EGraphT] @@ -16,7 +18,7 @@ private final class UpdatingImmutableEGraph[ result } - override def tryAddMany(nodes: Seq[ENode[NodeT]], parallelize: ParallelMap): Seq[AddNodeResult] = update { + override def tryAddMany(nodes: ArraySeq[ENode[NodeT]], parallelize: ParallelMap): ArraySeq[AddNodeResult] = update { _egraph.tryAddMany(nodes, parallelize) } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala index 4e67eea1..d02c78cb 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/Applier.scala @@ -1,12 +1,12 @@ package foresight.eqsat.rewriting -import foresight.eqsat.commands.{Command, CommandQueue} +import foresight.eqsat.commands.{CommandSchedule, CommandScheduleBuilder} import foresight.eqsat.readonly.EGraph /** * Describes how to **turn a match into edits** on an e-graph, without mutating it directly. * - * An [[Applier]] does not update the graph directly. Instead, it builds a [[Command]] describing the edits (e.g., + * An [[Applier]] does not update the graph directly. Instead, it builds a [[CommandSchedule]] describing the edits (e.g., * insertions, unions) that, when executed by the command engine, yields a new e-graph. * * Typically, an [[Applier]] is paired with a [[Searcher]] inside a [[Rule]]: @@ -15,7 +15,7 @@ import foresight.eqsat.readonly.EGraph * - those per-match commands are aggregated/optimized by the rule into a single operation. * * ## Contract - * - **Purity**: [[apply]] does not mutate `egraph`. It only describes work via a [[Command]]. + * - **Purity**: [[apply]] does not mutate `egraph`. It only describes work via a [[CommandSchedule]]. * - **Thread-safety**: Appliers may be invoked in parallel across distinct matches of the same snapshot. * They must be safe for concurrent use and avoid shared mutable state. * - **Idempotence** (recommended): When feasible, produce commands that tolerate duplicates (e.g., union @@ -38,7 +38,7 @@ trait Applier[NodeT, -MatchT, -EGraphT <: EGraph[NodeT]] { * @param egraph Immutable e-graph snapshot the match was derived from. * @return A command representing the intended edits for this match. */ - def apply(m: MatchT, egraph: EGraphT): Command[NodeT] + def apply(m: MatchT, egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit } /** @@ -54,7 +54,9 @@ object Applier { */ def ignore[NodeT, MatchT, EGraphT <: EGraph[NodeT]]: Applier[NodeT, MatchT, EGraphT] = new ReversibleApplier[NodeT, MatchT, EGraphT] { - override def apply(m: MatchT, egraph: EGraphT): Command[NodeT] = CommandQueue.empty + override def apply(m: MatchT, egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { + + } override def tryReverse: Option[Searcher[NodeT, MatchT, EGraphT]] = Some(Searcher.empty) } @@ -79,8 +81,10 @@ object Applier { filter: (MatchT, EGraphT) => Boolean) extends ReversibleApplier[NodeT, MatchT, EGraphT] { - override def apply(m: MatchT, egraph: EGraphT): Command[NodeT] = - if (filter(m, egraph)) applier.apply(m, egraph) else CommandQueue.empty + override def apply(m: MatchT, egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { + if (filter(m, egraph)) + applier.apply(m, egraph, builder) + } override def tryReverse: Option[Searcher[NodeT, MatchT, EGraphT]] = applier match { case r: ReversibleApplier[NodeT, MatchT, EGraphT] => r.tryReverse.map(_.filter(filter)) @@ -110,8 +114,9 @@ object Applier { f: (MatchT1, EGraphT) => MatchT2) extends Applier[NodeT, MatchT1, EGraphT] { - override def apply(m: MatchT1, egraph: EGraphT): Command[NodeT] = - applier.apply(f(m, egraph), egraph) + override def apply(m: MatchT1, egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { + applier.apply(f(m, egraph), egraph, builder) + } } /** @@ -131,8 +136,8 @@ object Applier { f: (MatchT1, EGraphT) => Iterable[MatchT2]) extends Applier[NodeT, MatchT1, EGraphT] { - override def apply(m: MatchT1, egraph: EGraphT): Command[NodeT] = - CommandQueue(f(m, egraph).map(applier.apply(_, egraph)).toSeq) + override def apply(m: MatchT1, egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = + f(m, egraph).foreach(applier.apply(_, egraph, builder)) } // ---------------------- Syntax sugar for Applier combinators ---------------------- diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala index 84cb7bc6..50e15c18 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/EClassSearcher.scala @@ -3,7 +3,7 @@ package foresight.eqsat.rewriting import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.EClassCall import foresight.eqsat.readonly.EGraph -import foresight.eqsat.commands.{Command, CommandQueue} +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.util.collections.StrictMapOps.toStrictMapOps import java.util.concurrent.atomic.AtomicIntegerArray @@ -167,31 +167,15 @@ private[eqsat] object EClassSearcher { * are passed to the provided continuation. * * @param eclassesToSearch The EClassesToSearch instance to get rules for. - * @param continuation Continuation to handle commands produced by rule applications. + * @param builder Command queue builder to collect produced commands. * @return A sequence of EClassSearcher instances that produce commands. */ def commandSearchers(eclassesToSearch: EClassesToSearch[EGraphT], - continuation: Command[NodeT] => Unit): Seq[EClassSearcher[NodeT, MatchT, EGraphT]] = { + builder: CommandScheduleBuilder[NodeT]): Seq[EClassSearcher[NodeT, MatchT, EGraphT]] = { rulesPerSharedEClassToSearch(eclassesToSearch).map { case Rule(_, searcher: EClassSearcher[NodeT, MatchT, _], applier) => val castSearcher = searcher.asInstanceOf[EClassSearcher[NodeT, MatchT, EGraphT]] - - castSearcher - .andThen(new castSearcher.ContinuationBuilder { - def apply(downstream: castSearcher.Continuation): castSearcher.Continuation = (m: MatchT, egraph: EGraphT) => { - if (downstream(m, egraph)) { - applier(m, egraph) match { - case CommandQueue(Seq()) => // Ignore no-op commands. - case cmd => - // Collect nontrivial commands. - continuation(cmd) - } - true - } else { - false - } - } - }) + castSearcher.andApplyAndCollect(applier, builder) case _ => throw new IllegalStateException("Expected EClassSearcher rule.") } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/Rewrite.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/Rewrite.scala index 4003a4e7..4f9f4e14 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/Rewrite.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/Rewrite.scala @@ -1,13 +1,11 @@ package foresight.eqsat.rewriting -import foresight.eqsat.commands.Command +import foresight.eqsat.commands.{CommandSchedule, CommandScheduleBuilder} import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.immutable import foresight.eqsat.mutable import foresight.eqsat.readonly.EGraph -import scala.collection.mutable.HashMap - /** * A rewrite rule encapsulates a search-and-replace operation on an e-graph. * @@ -45,6 +43,18 @@ trait Rewrite[NodeT, MatchT, -EGraphT <: EGraph[NodeT]] { */ def search(egraph: EGraphT, parallelize: ParallelMap = ParallelMap.default): Seq[MatchT] + /** + * Build a staged command from a precomputed set of matches. + * + * @param matches Matches to apply. + * @param egraph Target e-graph from which matches were derived. + * @param parallelize Parallel strategy used when building per-match commands. + * @param builder Command schedule builder to accumulate per-match commands into. + * @throws Rule.ApplicationException + * if constructing the per-match commands fails. + */ + def delayed(matches: Seq[MatchT], egraph: EGraphT, parallelize: ParallelMap, builder: CommandScheduleBuilder[NodeT]): Unit + /** * Build a staged command from a precomputed set of matches. * @@ -54,25 +64,46 @@ trait Rewrite[NodeT, MatchT, -EGraphT <: EGraph[NodeT]] { * @param matches Matches to apply. * @param egraph Target e-graph from which matches were derived. * @param parallelize Parallel strategy used when building per-match commands. - * @return An optimized [[CommandQueue]] encapsulated as a [[Command]]. + * @return An optimized [[CommandQueue]] encapsulated as a [[CommandSchedule]]. * @throws Rule.ApplicationException * if constructing the per-match commands fails. */ - def delayed(matches: Seq[MatchT], egraph: EGraphT, parallelize: ParallelMap): Command[NodeT] + final def delayed(matches: Seq[MatchT], egraph: EGraphT, parallelize: ParallelMap): CommandSchedule[NodeT] = { + // FIXME: build sequential variant that avoids concurrency overheads when parallelism is disabled + val collector = CommandScheduleBuilder.newConcurrentBuilder[NodeT] + delayed(matches, egraph, parallelize, collector) + collector.result() + } /** * Build a staged command that, when executed, applies this rule's matches to `egraph`. * - * This does not mutate the e-graph now; instead it returns a [[Command]] that you can: - * - enqueue into a [[CommandQueue]] with other rules; and - * - execute later as part of a larger saturation step. + * This does not mutate the e-graph now; instead it populates the provided [[CommandScheduleBuilder]] * * @param egraph The e-graph to search for matches. The staged command is intended to be run * against the same (or equivalent) snapshot. * @param parallelize Parallel strategy for both search and later application. - * @return A single, optimized [[Command]] that applies all current matches of this rule. + * @param builder Command schedule builder to accumulate per-match commands into. */ - def delayed(egraph: EGraphT, parallelize: ParallelMap = ParallelMap.default): Command[NodeT] + def delayed(egraph: EGraphT, parallelize: ParallelMap, builder: CommandScheduleBuilder[NodeT]): Unit + + /** + * Build a staged command that, when executed, applies this rule's matches to `egraph`. + * + * This does not mutate the e-graph now; instead it returns a [[CommandSchedule]] that you can execute later + * as part of a larger saturation step. + * + * @param egraph The e-graph to search for matches. The staged command is intended to be run + * against the same (or equivalent) snapshot. + * @param parallelize Parallel strategy for both search and later application. + * @return A single, optimized [[CommandSchedule]] that applies all current matches of this rule. + */ + final def delayed(egraph: EGraphT, parallelize: ParallelMap = ParallelMap.default): CommandSchedule[NodeT] = { + // FIXME: build sequential variant that avoids concurrency overheads when parallelism is disabled + val collector = CommandScheduleBuilder.newConcurrentBuilder[NodeT] + delayed(egraph, parallelize, collector) + collector.result() + } /** * Search the e-graph for matches and apply them immediately, if any. @@ -93,7 +124,7 @@ trait Rewrite[NodeT, MatchT, -EGraphT <: EGraph[NodeT]] { parallelize: ParallelMap = ParallelMap.default ): Option[MutEGraphT] = { val mutGraph = mutable.FreezableEGraph[NodeT, MutEGraphT](egraph) - val anyChanges = delayed(egraph, parallelize)(mutGraph, HashMap.empty, parallelize) + val anyChanges = delayed(egraph, parallelize)(mutGraph, parallelize) if (anyChanges) { Some(mutGraph.freeze()) } else { @@ -136,6 +167,6 @@ trait Rewrite[NodeT, MatchT, -EGraphT <: EGraph[NodeT]] { egraph: MutEGraphT, parallelize: ParallelMap = ParallelMap.default ): Boolean = { - delayed(egraph, parallelize)(egraph, HashMap.empty, parallelize) + delayed(egraph, parallelize)(egraph, parallelize) } } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/Rule.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/Rule.scala index e0e52963..be2f0072 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/Rule.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/Rule.scala @@ -1,9 +1,7 @@ package foresight.eqsat.rewriting -import foresight.eqsat.commands.{Command, CommandQueue} -import foresight.eqsat.parallel.{OperationCanceledException, ParallelMap} -import foresight.eqsat.mutable -import foresight.eqsat.immutable +import foresight.eqsat.commands.CommandScheduleBuilder +import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.readonly.EGraph /** @@ -12,7 +10,7 @@ import foresight.eqsat.readonly.EGraph * A rule has two primary components: * * 1. **Search** – uses a [[Searcher]] to discover all matches of the rule in a given [[EGraph]]. - * 2. **Apply** – for each match, uses an [[Applier]] to produce a [[Command]] that mutates the e-graph + * 2. **Apply** – for each match, uses an [[Applier]] to produce a [[CommandSchedule]] that mutates the e-graph * (e.g., unions, insertions). Applications may be parallelized via [[ParallelMap]]. * * After search and apply, a rule performs a **composition step**: @@ -48,7 +46,7 @@ import foresight.eqsat.readonly.EGraph * @tparam EGraphT Concrete e-graph type this rule targets. Must be both [[EGraphLike]] and [[EGraph]]. * @param name Human-readable rule name (used in logs/diagnostics). * @param searcher Component responsible for finding matches (see [[Searcher.search]]). - * @param applier Component that turns a match into a [[Command]] acting on the e-graph. + * @param applier Component that turns a match into a [[CommandSchedule]] acting on the e-graph. * @example Defining and running a rule immediately * {{{ * val constantFold: Rule[MyNode, MyMatch, MyEGraph] = @@ -96,13 +94,11 @@ final case class Rule[NodeT, MatchT, EGraphT <: EGraph[NodeT]](override val name * @param egraph The e-graph to search for matches. The staged command is intended to be run * against the same (or equivalent) snapshot. * @param parallelize Parallel strategy for both search and later application. - * @return A single, optimized [[Command]] that applies all current matches of this rule. + * @param builder Command schedule builder to accumulate per-match commands into. */ - override def delayed(egraph: EGraphT, parallelize: ParallelMap = ParallelMap.default): Command[NodeT] = { - val pipeline = searcher.andApply(applier) - aggregateCommands( - pipeline.searchAndCollect(egraph, parallelize.child(s"match+apply $name")), - egraph) + override def delayed(egraph: EGraphT, parallelize: ParallelMap, builder: CommandScheduleBuilder[NodeT]): Unit = { + val pipeline = searcher.andApplyAndCollect(applier, builder) + pipeline.search(egraph, parallelize.child(s"match+apply $name")) } /** @@ -114,24 +110,12 @@ final case class Rule[NodeT, MatchT, EGraphT <: EGraph[NodeT]](override val name * @param matches Matches to apply. * @param egraph Target e-graph from which matches were derived. * @param parallelize Parallel strategy used when building per-match commands. - * @return An optimized [[CommandQueue]] encapsulated as a [[Command]]. + * @param builder Command schedule builder to accumulate per-match commands into. * @throws Rule.ApplicationException * if constructing the per-match commands fails. */ - override def delayed(matches: Seq[MatchT], egraph: EGraphT, parallelize: ParallelMap): Command[NodeT] = { - aggregateCommands( - parallelize.child(s"apply $name")[MatchT, Command[NodeT]](matches, applier.apply(_, egraph)).toSeq, - egraph) - } - - private def aggregateCommands(buildCommands: => Seq[Command[NodeT]], egraph: EGraphT): Command[NodeT] = { - try { - CommandQueue(buildCommands) - } catch { - case e: OperationCanceledException => throw e - case e: Exception => - throw Rule.ApplicationException(this, egraph, e) - } + override def delayed(matches: Seq[MatchT], egraph: EGraphT, parallelize: ParallelMap, builder: CommandScheduleBuilder[NodeT]): Unit = { + parallelize.child(s"apply $name")[MatchT, Unit](matches, applier.apply(_, egraph, builder)) } /** diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/Searcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/Searcher.scala index 026cf784..b255c868 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/Searcher.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/Searcher.scala @@ -1,6 +1,5 @@ package foresight.eqsat.rewriting -import foresight.eqsat.commands.{Command, CommandQueue} import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.patterns.PatternMatch @@ -138,23 +137,6 @@ trait Searcher[NodeT, MatchT, EGraphT <: EGraph[NodeT]] TransformSearcher(SearcherContinuation.identityBuilder) } - - /** - * Chain this searcher with an [[Applier]], producing a searcher that runs this searcher - * and then immediately applies each match using the given applier. - * - * This is useful for building pipelines of searchers and appliers without needing to - * construct intermediate sequences of matches. - * - * @param applier The applier to run on each match found by this searcher. - * @return A new searcher that applies the given applier to each match found. - */ - final def andApply(applier: Applier[NodeT, MatchT, EGraphT]): Searcher[NodeT, Command[NodeT], EGraphT] = { - transform(applier.apply).filter { - case (CommandQueue(Seq()), _) => false - case _ => true - } - } } /** diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/SearcherLike.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/SearcherLike.scala index 79272f49..65559b87 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/SearcherLike.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/SearcherLike.scala @@ -2,6 +2,7 @@ package foresight.eqsat.rewriting import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} import foresight.eqsat.Slot +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.readonly.EGraph /** @@ -105,6 +106,28 @@ trait SearcherLike[Node, Match, EGraphT <: EGraph[Node], +This <: SearcherLike[N final def flatMap(f: (Match, EGraphT) => Iterable[Match]): This = { andThen(SearcherContinuation.flatMapBuilder(f)) } + + /** + * Chains an applier to the searcher-like object, applying it to each match + * and collecting the resulting commands using the provided collector. + * + * @param applier The applier to apply to each match. + * @param collector The command schedule builder to collect the resulting commands. + * @return A new instance of the searcher-like object that applies the applier and collects commands. + */ + private[eqsat] final def andApplyAndCollect(applier: Applier[Node, Match, EGraphT], + collector: CommandScheduleBuilder[Node]): This = { + andThen(new ContinuationBuilder { + def apply(downstream: Continuation): Continuation = (m: Match, egraph: EGraphT) => { + if (downstream(m, egraph)) { + applier(m, egraph, collector) + true + } else { + false + } + } + }) + } } /** diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/package.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/package.scala index 71d31cce..aa083d01 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/package.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/package.scala @@ -9,7 +9,7 @@ package foresight.eqsat * ## Core ideas * * The system is built on **immutable e-graphs**, where every change is represented - * as a [[foresight.eqsat.commands.Command]] value. Instead of mutating the e-graph + * as a [[foresight.eqsat.commands.CommandSchedule]] value. Instead of mutating the e-graph * in place, commands are executed later to produce a new snapshot. * * Rules follow a **search → apply → compose** flow. A [[Searcher]] is responsible @@ -36,7 +36,7 @@ package foresight.eqsat * can be enriched with [[Searcher.product]], filter, map and flatMap combinators. * * Next, an applier is defined to convert each match into a - * [[foresight.eqsat.commands.Command]]. The searcher and applier are then paired + * [[foresight.eqsat.commands.CommandSchedule]]. The searcher and applier are then paired * to form a [[Rule]]. Rules can be executed immediately using [[Rule.apply]], or * staged for later execution with [[Rule.delayed]]—either individually or batched * with others. diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 730b0195..b8b4ec3c 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -1,12 +1,13 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.collections.SlotSeq -import foresight.eqsat.commands.{Command, CommandQueueBuilder} +import foresight.eqsat.commands.CommandScheduleBuilder import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} import foresight.eqsat.{EClassSymbol, MixedTree, Slot} -import scala.collection.compat._ +import scala.collection.compat.immutable.ArraySeq +import scala.runtime.IntRef /** * An applier that applies a pattern match to an e-graph. @@ -18,11 +19,9 @@ import scala.collection.compat._ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedTree[NodeT, Pattern.Var]) extends ReversibleApplier[NodeT, PatternMatch[NodeT], EGraphT] { - override def apply(m: PatternMatch[NodeT], egraph: EGraphT): Command[NodeT] = { - val builder = new CommandQueueBuilder[NodeT]() + override def apply(m: PatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { val symbol = instantiateAsSimplifiedAddCommand(pattern, m, egraph, builder) builder.unionSimplified(EClassSymbol.real(m.root), symbol, egraph) - builder.result() } override def tryReverse: Option[Searcher[NodeT, PatternMatch[NodeT], EGraphT]] = { @@ -69,13 +68,13 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], egraph: EGraphT, - builder: CommandQueueBuilder[NodeT]) { - def instantiate(pattern: MixedTree[NodeT, Pattern.Var]): EClassSymbol = { + builder: CommandScheduleBuilder[NodeT]) { + def instantiate(pattern: MixedTree[NodeT, Pattern.Var], maxBatch: IntRef): EClassSymbol = { pattern match { case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) case MixedTree.Node(t, defs@Seq(), uses, args) => // No definitions, so we can reuse the PatternMatch and its original slot mapping - addSimplifiedNode(t, defs, uses, args) + addSimplifiedNode(t, defs, uses, args, maxBatch) case MixedTree.Node(t, defs, uses, args) => val defSlots = defs.map { (s: Slot) => @@ -85,25 +84,31 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - new SimplifiedAddCommandInstantiator(newMatch, egraph, builder).addSimplifiedNode(t, defSlots, uses, args) + new SimplifiedAddCommandInstantiator(newMatch, egraph, builder).addSimplifiedNode(t, defSlots, uses, args, maxBatch) } } private def addSimplifiedNode(nodeType: NodeT, definitions: SlotSeq, uses: SlotSeq, - args: immutable.ArraySeq[MixedTree[NodeT, Pattern.Var]]): EClassSymbol = { - val argSymbols = CommandQueueBuilder.symbolArrayFrom(args, instantiate) + args: ArraySeq[MixedTree[NodeT, Pattern.Var]], + maxBatch: IntRef): EClassSymbol = { + val argMaxBatch = IntRef(0) + val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) - builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, egraph) + val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph) + if (argMaxBatch.elem > maxBatch.elem) { + maxBatch.elem = argMaxBatch.elem + } + result } } private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], m: PatternMatch[NodeT], egraph: EGraphT, - builder: CommandQueueBuilder[NodeT]): EClassSymbol = { + builder: CommandScheduleBuilder[NodeT]): EClassSymbol = { - new SimplifiedAddCommandInstantiator(m, egraph, builder).instantiate(pattern) + new SimplifiedAddCommandInstantiator(m, egraph, builder).instantiate(pattern, new IntRef(0)) } } diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index 794a6906..80b3ee4a 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -1,17 +1,14 @@ package foresight.eqsat.saturation -import foresight.eqsat.commands.{Command, CommandQueue} +import foresight.eqsat.commands.{CommandSchedule, CommandScheduleBuilder} import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.rewriting.{EClassSearcher, EClassesToSearch, PortableMatch, Rewrite, Rule} +import foresight.eqsat.rewriting.{EClassSearcher, PortableMatch, Rewrite} import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithRecordedApplications} import foresight.eqsat.mutable.{FreezableEGraph, EGraph => MutableEGraph} import foresight.eqsat.readonly -import foresight.eqsat.rewriting.SearcherContinuation.Continuation import foresight.util.collections.StrictMapOps.toStrictMapOps import foresight.util.collections.UnsafeSeqFromArray -import scala.collection.mutable.HashMap - /** * A strategy that searches for matches of a set of rules in an e-graph and applies them. * @@ -34,18 +31,20 @@ trait SearchAndApply[NodeT, -RuleT <: Rewrite[NodeT, MatchT, _], EGraphT <: read parallelize: ParallelMap): Seq[MatchT] /** - * Produces a command that applies the given matches of the rule to the e-graph. + * Converts the matches found for the given rule into commands that can be applied to the e-graph. * * @param rule The rule whose matches are to be applied. * @param matches The matches to be applied. * @param egraph The e-graph to which the matches are applied. * @param parallelize A parallelization strategy for applying the matches. - * @return A command that applies the matches to the e-graph. + * @param collector A command schedule builder to collect the generated commands. + * */ protected def delayed(rule: RuleT, matches: Seq[MatchT], egraph: EGraphT, - parallelize: ParallelMap): Command[NodeT] + parallelize: ParallelMap, + collector: CommandScheduleBuilder[NodeT]): Unit /** * Applies the given command to the e-graph. @@ -58,7 +57,7 @@ trait SearchAndApply[NodeT, -RuleT <: Rewrite[NodeT, MatchT, _], EGraphT <: read * @param parallelize A parallelization strategy for applying the command. * @return An updated e-graph with the command applied, or None if the command made no changes. */ - protected def update(command: Command[NodeT], + protected def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, parallelize: ParallelMap): Option[EGraphT] @@ -83,42 +82,42 @@ trait SearchAndApply[NodeT, -RuleT <: Rewrite[NodeT, MatchT, _], EGraphT <: read } /** - * Produces a sequence of commands that apply the given matches of the rules to the e-graph. + * Converts the matches found for the given rules into commands that can be applied to the e-graph. * @param rules The rules whose matches are to be applied. * @param matches A map from rule names to sequences of matches found for each rule. * @param egraph The e-graph to which the matches are applied. * @param parallelize A parallelization strategy for applying the matches. - * @return A sequence of commands that apply the matches to the e-graph. + * @param builder A command schedule builder to collect the generated commands. */ final def delayed(rules: Seq[RuleT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, - parallelize: ParallelMap): Seq[Command[NodeT]] = { + parallelize: ParallelMap, + builder: CommandScheduleBuilder[NodeT]): Unit = { + val ruleApplicationParallelize = parallelize.child("rule application") - ruleApplicationParallelize[RuleT, Command[NodeT]](rules, (rule: RuleT) => { + ruleApplicationParallelize[RuleT, Unit](rules, (rule: RuleT) => { val newMatches = matches(rule.name) - delayed(rule, newMatches, egraph, ruleApplicationParallelize) - }).toSeq + delayed(rule, newMatches, egraph, ruleApplicationParallelize, builder) + }) } /** - * Applies the given sequence of commands to the e-graph. This method first optimizes the sequence of commands - * and then applies the optimized command to the e-graph. - * @param commands The sequence of commands to apply. - * @param matches A map from rule names to sequences of matches found for each rule. This is provided for - * commands that may need to know which matches were found for which rules (e.g., for caching). - * @param egraph The e-graph to which the commands are applied. Depending on the implementation, this may be a mutable - * or immutable e-graph. If it is mutable, the commands may modify it in place. If it is immutable, - * the commands will return a new e-graph if they make any changes. - * @param parallelize A parallelization strategy for applying the commands. - * @return An updated e-graph with the commands applied, or None if no commands made any changes. + * Converts the matches found for the given rules into commands that can be applied to the e-graph. + * @param rules The rules whose matches are to be applied. + * @param matches A map from rule names to sequences of matches found for each rule. + * @param egraph The e-graph to which the matches are applied. + * @param parallelize A parallelization strategy for applying the matches. + * @return A command schedule that applies the matches to the e-graph. */ - final def update(commands: Seq[Command[NodeT]], - matches: Map[String, Seq[MatchT]], - egraph: EGraphT, - parallelize: ParallelMap): Option[EGraphT] = { - val command = CommandQueue(commands).optimized - update(command, matches, egraph, parallelize) + final def delayed(rules: Seq[RuleT], + matches: Map[String, Seq[MatchT]], + egraph: EGraphT, + parallelize: ParallelMap): CommandSchedule[NodeT] = { + // FIXME: build sequential variant that avoids concurrency overheads when parallelism is disabled + val collector = CommandScheduleBuilder.newConcurrentBuilder[NodeT] + delayed(rules, matches, egraph, parallelize, collector) + collector.result() } /** @@ -173,11 +172,11 @@ object SearchAndApply { new NoMatchCaching[NodeT, EGraphT, MatchT] { override def searchLoopInterchange: Boolean = false - override def update(command: Command[NodeT], + override def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, parallelize: ParallelMap): Option[EGraphT] = { - val anyChanges = command(egraph, HashMap.empty, parallelize) + val anyChanges = command(egraph, parallelize) if (anyChanges) Some(egraph) else None } } @@ -198,12 +197,11 @@ object SearchAndApply { new NoMatchCaching[NodeT, EGraphT, MatchT] { override def searchLoopInterchange: Boolean = false - override def update(command: Command[NodeT], + override def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphT, parallelize: ParallelMap): Option[EGraphT] = { - val (newEGraph, _) = command.applyImmutable(egraph, Map.empty, parallelize) - newEGraph + command.applyImmutable(egraph, parallelize) } } } @@ -232,17 +230,18 @@ object SearchAndApply { override def delayed(rule: Rewrite[NodeT, MatchT, EGraphT], matches: Seq[MatchT], egraph: EGraphWithRecordedApplications[NodeT, EGraphT, MatchT], - parallelize: ParallelMap): Command[NodeT] = { - rule.delayed(matches, egraph.egraph, parallelize) + parallelize: ParallelMap, + builder: CommandScheduleBuilder[NodeT]): Unit = { + rule.delayed(matches, egraph.egraph, parallelize, builder) } - override def update(command: Command[NodeT], + override def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], egraph: EGraphWithRecordedApplications[NodeT, EGraphT, MatchT], parallelize: ParallelMap): Option[EGraphWithRecordedApplications[NodeT, EGraphT, MatchT]] = { val recorded = matches.mapValuesStrict(_.toSet) val mutEGraph = FreezableEGraph[NodeT, EGraphWithRecordedApplications[NodeT, EGraphT, MatchT]](egraph.record(recorded)) - val anyChanges = command(mutEGraph, HashMap.empty, parallelize) + val anyChanges = command(mutEGraph, parallelize) if (anyChanges) Some(mutEGraph.freeze()) else None } } @@ -269,8 +268,9 @@ object SearchAndApply { final override def delayed(rule: Rewrite[NodeT, MatchT, EGraphT], matches: Seq[MatchT], egraph: EGraphT, - parallelize: ParallelMap): Command[NodeT] = { - rule.delayed(matches, egraph, parallelize) + parallelize: ParallelMap, + builder: CommandScheduleBuilder[NodeT]): Unit = { + rule.delayed(matches, egraph, parallelize, builder) } final override def apply(rules: Seq[Rewrite[NodeT, MatchT, EGraphT]], @@ -278,19 +278,19 @@ object SearchAndApply { parallelize: ParallelMap): Option[EGraphT] = { val ruleMatchingAndApplicationParallelize = parallelize.child("rule matching+application") + // FIXME: build sequential variant that avoids concurrency overheads when parallelism is disabled + val collector = CommandScheduleBuilder.newConcurrentBuilder[NodeT] + if (!searchLoopInterchange || egraph.classCount <= EClassSearcher.smallEGraphThreshold) { // Small e-graph optimization: for small e-graphs, the overhead of partitioning and // fusing rule applications outweighs the benefits. Just process each rule normally. - val updates = ruleMatchingAndApplicationParallelize( + ruleMatchingAndApplicationParallelize( rules, (rule: Rewrite[NodeT, MatchT, EGraphT]) => { - rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + rule.delayed(egraph, ruleMatchingAndApplicationParallelize, collector) } - ).toSeq - update(updates, Map.empty[String, Seq[MatchT]], egraph, parallelize) + ) } else { - val updates = Seq.newBuilder[Command[NodeT]] - // Idea: EClassSearcher rules are the common case, and they apply in parallel over a subset of // e-classes in the e-graph. If multiple rules share the same subset of e-classes to search, // we can group them together to fuse iterations over those e-classes. Fusion both reduces @@ -300,24 +300,22 @@ object SearchAndApply { // Process regular rules normally. for (rule <- partitioned.regularRules) { - updates += rule.delayed(egraph, ruleMatchingAndApplicationParallelize) + rule.delayed(egraph, ruleMatchingAndApplicationParallelize, collector) } // Process shared EClassesToSearch rules together. for (eclassesToSearch <- partitioned.rulesPerSharedEClassToSearch.keys) { - updates ++= ruleMatchingAndApplicationParallelize.collectFrom[Command[NodeT]] { (add: Command[NodeT] => Unit) => - val commandSearchers = partitioned.commandSearchers(eclassesToSearch, add) - EClassSearcher.searchMultiple( - UnsafeSeqFromArray(commandSearchers.toArray), - eclassesToSearch(egraph), - egraph, - ruleMatchingAndApplicationParallelize - ) - } + val commandSearchers = partitioned.commandSearchers(eclassesToSearch, collector) + EClassSearcher.searchMultiple( + UnsafeSeqFromArray(commandSearchers.toArray), + eclassesToSearch(egraph), + egraph, + ruleMatchingAndApplicationParallelize + ) } - - update(updates.result(), Map.empty[String, Seq[MatchT]], egraph, parallelize) } + + update(collector.result(), Map.empty[String, Seq[MatchT]], egraph, parallelize) } } } diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala index 41e2d5a7..a49e836d 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala @@ -1,207 +1,207 @@ -package foresight.eqsat.commands - -import foresight.eqsat.collections.SlotSeq -import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassSymbol, ENode, ENodeSymbol, MixedTree, Slot} -import foresight.eqsat.immutable.EGraph -import org.junit.Test - -import scala.collection.compat.immutable.ArraySeq - -class CommandQueueBuilderTest { - /** - * An empty queue does nothing. - */ - @Test - def emptyQueueDoesNothing(): Unit = { - val builder = new CommandQueueBuilder[Int] - - val egraph = EGraph.empty[Int] - - assert(builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential)._1.isEmpty) - } - - /** - * Adding a node to the queue adds it to the e-graph. - */ - @Test - def addNode(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - - val node = ENodeSymbol(0, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty) - builder.add(node) - - val queue = builder.result() - assert(queue.commands.size == 1) - assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) - assert(queue.commands.head.asInstanceOf[AddManyCommand[Int]].nodes.head._2 == node) - - val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) - assert(egraph2.classes.size == 1) - assert(egraph2.nodes(egraph2.canonicalize(egraph2.classes.head)).head == node.reify(Map.empty)) - } - - /** - * Adding a tree to the queue adds it to the e-graph. - */ - @Test - def addSingleNodeTree(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - - val tree = MixedTree.Node[Int, EClassSymbol](0, Seq.empty[Slot], Seq.empty[Slot], Seq.empty[MixedTree[Int, EClassSymbol]]) - builder.add(tree) - - val queue = builder.result() - assert(queue.commands.size == 1) - assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) - - val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) - assert(egraph2.classes.size == 1) - } - - /** - * Adding a tree with a child to the queue adds it to the e-graph. - */ - @Test - def addTreeWithChild(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - - val child = MixedTree.unslotted(1, Seq.empty[MixedTree[Int, EClassSymbol]]) - val tree = MixedTree.unslotted(0, Seq(child)) - builder.add(tree) - - val queue = builder.result() - assert(queue.commands.size == 2) - assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) - assert(queue.commands(1).isInstanceOf[AddManyCommand[Int]]) - - val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) - assert(egraph2.classes.size == 2) - } - - /** - * Adding a tree with a call to the queue adds it to the e-graph. - */ - @Test - def addTreeWithCall(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - val (call, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) - - val tree = MixedTree.Atom[Int, EClassSymbol](EClassSymbol.real(call)) - builder.add(tree) - - val queue = builder.result() - assert(queue.commands.isEmpty) - - val (None, _) = builder.result().applyImmutable(egraph2, Map.empty, ParallelMap.sequential) - } - - /** - * Unions two e-classes in the e-graph. - */ - @Test - def union(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) - val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) - - builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) - - val queue = builder.result() - assert(queue.commands.size == 1) - assert(queue.commands.head.isInstanceOf[UnionManyCommand[Int]]) - - val (Some(egraph4), _) = builder.result().applyImmutable(egraph3, Map.empty, ParallelMap.sequential) - assert(egraph4.classes.size == 1) - assert(egraph4.areSame(a, b)) - } - - /** - * Unions are combined into a single command by CommandQueue.optimized. - */ - @Test - def optimizedUnions(): Unit = { - val builder = new CommandQueueBuilder[Int] - val egraph = EGraph.empty[Int] - val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) - val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) - val (c, egraph4) = egraph3.add(ENode(2, Seq.empty, Seq.empty, Seq.empty)) - - builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) - builder.union(EClassSymbol.real(b), EClassSymbol.real(c)) - - val naiveQueue = builder.result() - assert(naiveQueue.commands.size == 2) - assert(naiveQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) - assert(naiveQueue.commands(1).isInstanceOf[UnionManyCommand[Int]]) - - val optimizedQueue = builder.result().optimized - assert(optimizedQueue.commands.size == 1) - assert(optimizedQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) - - val (Some(egraph5), _) = optimizedQueue.applyImmutable(egraph4, Map.empty, ParallelMap.sequential) - assert(egraph5.classes.size == 1) - assert(egraph5.areSame(a, b)) - assert(egraph5.areSame(b, c)) - } - - @Test - def xyUnionYxProducesPermutation(): Unit = { - val builder = new CommandQueueBuilder[Int] - - val x = Slot.fresh() - val y = Slot.fresh() - - val node1 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(x, y), ArraySeq.empty) - val node2 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(y, x), ArraySeq.empty) - - val a = builder.add(node1) - val b = builder.add(node2) - - builder.union(a, b) - - val node3 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(a)) - val node4 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(b)) - - val c = builder.add(node3) - val d = builder.add(node4) - - for (queue <- Seq(builder.result(), builder.result().optimized)) { - val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) - - assert(egraph.classes.size == 2) - assert(egraph.areSame(a.reify(reification), b.reify(reification))) - assert(egraph.areSame(c.reify(reification), d.reify(reification))) - } - } - - @Test - def constructTreeWithSlots(): Unit = { - val builder = new CommandQueueBuilder[Int] - - val x = Slot.fresh() - val y = Slot.fresh() - - val tree1 = builder.add(MixedTree.Node(2, Seq.empty, Seq(y), Seq.empty)) - val tree2 = builder.add( - MixedTree.Node(0, Seq(x), Seq.empty, Seq(MixedTree.Node(1, Seq.empty, Seq(x), Seq(MixedTree.Atom(tree1)))))) - - for (queue <- Seq(builder.result(), builder.result().optimized)) { - val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) - - assert(egraph.classes.size == 3) - assert(tree1.reify(reification).args.valueSet == Set(y)) - assert(tree2.reify(reification).args.valueSet == Set(y)) - - assert(!egraph.contains( - MixedTree.Node(0, Seq(x), Seq.empty, Seq( - MixedTree.Node(1, Seq.empty, Seq(y), Seq(MixedTree.Atom(tree1.reify(reification)))))))) - } - } -} - +//package foresight.eqsat.commands +// +//import foresight.eqsat.collections.SlotSeq +//import foresight.eqsat.parallel.ParallelMap +//import foresight.eqsat.{EClassSymbol, ENode, ENodeSymbol, MixedTree, Slot} +//import foresight.eqsat.immutable.EGraph +//import org.junit.Test +// +//import scala.collection.compat.immutable.ArraySeq +// +//class CommandQueueBuilderTest { +// /** +// * An empty queue does nothing. +// */ +// @Test +// def emptyQueueDoesNothing(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// +// val egraph = EGraph.empty[Int] +// +// assert(builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential)._1.isEmpty) +// } +// +// /** +// * Adding a node to the queue adds it to the e-graph. +// */ +// @Test +// def addNode(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// +// val node = ENodeSymbol(0, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty) +// builder.add(node) +// +// val queue = builder.result() +// assert(queue.commands.size == 1) +// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) +// assert(queue.commands.head.asInstanceOf[AddManyCommand[Int]].nodes.head._2 == node) +// +// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) +// assert(egraph2.classes.size == 1) +// assert(egraph2.nodes(egraph2.canonicalize(egraph2.classes.head)).head == node.reify(Map.empty)) +// } +// +// /** +// * Adding a tree to the queue adds it to the e-graph. +// */ +// @Test +// def addSingleNodeTree(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// +// val tree = MixedTree.Node[Int, EClassSymbol](0, Seq.empty[Slot], Seq.empty[Slot], Seq.empty[MixedTree[Int, EClassSymbol]]) +// builder.add(tree) +// +// val queue = builder.result() +// assert(queue.commands.size == 1) +// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) +// +// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) +// assert(egraph2.classes.size == 1) +// } +// +// /** +// * Adding a tree with a child to the queue adds it to the e-graph. +// */ +// @Test +// def addTreeWithChild(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// +// val child = MixedTree.unslotted(1, Seq.empty[MixedTree[Int, EClassSymbol]]) +// val tree = MixedTree.unslotted(0, Seq(child)) +// builder.add(tree) +// +// val queue = builder.result() +// assert(queue.commands.size == 2) +// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) +// assert(queue.commands(1).isInstanceOf[AddManyCommand[Int]]) +// +// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) +// assert(egraph2.classes.size == 2) +// } +// +// /** +// * Adding a tree with a call to the queue adds it to the e-graph. +// */ +// @Test +// def addTreeWithCall(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// val (call, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) +// +// val tree = MixedTree.Atom[Int, EClassSymbol](EClassSymbol.real(call)) +// builder.add(tree) +// +// val queue = builder.result() +// assert(queue.commands.isEmpty) +// +// val (None, _) = builder.result().applyImmutable(egraph2, Map.empty, ParallelMap.sequential) +// } +// +// /** +// * Unions two e-classes in the e-graph. +// */ +// @Test +// def union(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) +// val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) +// +// builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) +// +// val queue = builder.result() +// assert(queue.commands.size == 1) +// assert(queue.commands.head.isInstanceOf[UnionManyCommand[Int]]) +// +// val (Some(egraph4), _) = builder.result().applyImmutable(egraph3, Map.empty, ParallelMap.sequential) +// assert(egraph4.classes.size == 1) +// assert(egraph4.areSame(a, b)) +// } +// +// /** +// * Unions are combined into a single command by CommandQueue.optimized. +// */ +// @Test +// def optimizedUnions(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// val egraph = EGraph.empty[Int] +// val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) +// val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) +// val (c, egraph4) = egraph3.add(ENode(2, Seq.empty, Seq.empty, Seq.empty)) +// +// builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) +// builder.union(EClassSymbol.real(b), EClassSymbol.real(c)) +// +// val naiveQueue = builder.result() +// assert(naiveQueue.commands.size == 2) +// assert(naiveQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) +// assert(naiveQueue.commands(1).isInstanceOf[UnionManyCommand[Int]]) +// +// val optimizedQueue = builder.result().optimized +// assert(optimizedQueue.commands.size == 1) +// assert(optimizedQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) +// +// val (Some(egraph5), _) = optimizedQueue.applyImmutable(egraph4, Map.empty, ParallelMap.sequential) +// assert(egraph5.classes.size == 1) +// assert(egraph5.areSame(a, b)) +// assert(egraph5.areSame(b, c)) +// } +// +// @Test +// def xyUnionYxProducesPermutation(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// +// val x = Slot.fresh() +// val y = Slot.fresh() +// +// val node1 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(x, y), ArraySeq.empty) +// val node2 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(y, x), ArraySeq.empty) +// +// val a = builder.add(node1) +// val b = builder.add(node2) +// +// builder.union(a, b) +// +// val node3 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(a)) +// val node4 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(b)) +// +// val c = builder.add(node3) +// val d = builder.add(node4) +// +// for (queue <- Seq(builder.result(), builder.result().optimized)) { +// val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) +// +// assert(egraph.classes.size == 2) +// assert(egraph.areSame(a.reify(reification), b.reify(reification))) +// assert(egraph.areSame(c.reify(reification), d.reify(reification))) +// } +// } +// +// @Test +// def constructTreeWithSlots(): Unit = { +// val builder = new CommandQueueBuilder[Int] +// +// val x = Slot.fresh() +// val y = Slot.fresh() +// +// val tree1 = builder.add(MixedTree.Node(2, Seq.empty, Seq(y), Seq.empty)) +// val tree2 = builder.add( +// MixedTree.Node(0, Seq(x), Seq.empty, Seq(MixedTree.Node(1, Seq.empty, Seq(x), Seq(MixedTree.Atom(tree1)))))) +// +// for (queue <- Seq(builder.result(), builder.result().optimized)) { +// val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) +// +// assert(egraph.classes.size == 3) +// assert(tree1.reify(reification).args.valueSet == Set(y)) +// assert(tree2.reify(reification).args.valueSet == Set(y)) +// +// assert(!egraph.contains( +// MixedTree.Node(0, Seq(x), Seq.empty, Seq( +// MixedTree.Node(1, Seq.empty, Seq(y), Seq(MixedTree.Atom(tree1.reify(reification)))))))) +// } +// } +//} +// From 14e10951cca7d46061cd29cd40f265b523969856 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:14:23 -0500 Subject: [PATCH 12/33] Fix additions method to handle empty batchZero case --- .../scala/foresight/eqsat/commands/CommandSchedule.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala index c825cd54..0730646e 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala @@ -25,7 +25,11 @@ final case class CommandSchedule[NodeT](batchZero: (ArraySeq[EClassSymbol.Virtua * The additions scheduled in this command schedule, grouped by batch. */ def additions: Seq[(ArraySeq[EClassSymbol.Virtual], ArraySeq[ENodeSymbol[NodeT]])] = { - batchZero +: otherBatches + if (batchZero._1.isEmpty) { + otherBatches + } else { + batchZero +: otherBatches + } } private type ReificationMap = util.IdentityHashMap[EClassSymbol.Virtual, EClassCall] From e7043f2f279160ebddb7924fbb148dcc051b9949 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:14:29 -0500 Subject: [PATCH 13/33] Add unit tests for CommandScheduleBuilder functionality --- .../commands/CommandQueueBuilderTest.scala | 207 ------------------ .../commands/CommandScheduleBuilderTest.scala | 169 ++++++++++++++ 2 files changed, 169 insertions(+), 207 deletions(-) delete mode 100644 foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala create mode 100644 foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala deleted file mode 100644 index a49e836d..00000000 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandQueueBuilderTest.scala +++ /dev/null @@ -1,207 +0,0 @@ -//package foresight.eqsat.commands -// -//import foresight.eqsat.collections.SlotSeq -//import foresight.eqsat.parallel.ParallelMap -//import foresight.eqsat.{EClassSymbol, ENode, ENodeSymbol, MixedTree, Slot} -//import foresight.eqsat.immutable.EGraph -//import org.junit.Test -// -//import scala.collection.compat.immutable.ArraySeq -// -//class CommandQueueBuilderTest { -// /** -// * An empty queue does nothing. -// */ -// @Test -// def emptyQueueDoesNothing(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// -// val egraph = EGraph.empty[Int] -// -// assert(builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential)._1.isEmpty) -// } -// -// /** -// * Adding a node to the queue adds it to the e-graph. -// */ -// @Test -// def addNode(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// -// val node = ENodeSymbol(0, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty) -// builder.add(node) -// -// val queue = builder.result() -// assert(queue.commands.size == 1) -// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) -// assert(queue.commands.head.asInstanceOf[AddManyCommand[Int]].nodes.head._2 == node) -// -// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) -// assert(egraph2.classes.size == 1) -// assert(egraph2.nodes(egraph2.canonicalize(egraph2.classes.head)).head == node.reify(Map.empty)) -// } -// -// /** -// * Adding a tree to the queue adds it to the e-graph. -// */ -// @Test -// def addSingleNodeTree(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// -// val tree = MixedTree.Node[Int, EClassSymbol](0, Seq.empty[Slot], Seq.empty[Slot], Seq.empty[MixedTree[Int, EClassSymbol]]) -// builder.add(tree) -// -// val queue = builder.result() -// assert(queue.commands.size == 1) -// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) -// -// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) -// assert(egraph2.classes.size == 1) -// } -// -// /** -// * Adding a tree with a child to the queue adds it to the e-graph. -// */ -// @Test -// def addTreeWithChild(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// -// val child = MixedTree.unslotted(1, Seq.empty[MixedTree[Int, EClassSymbol]]) -// val tree = MixedTree.unslotted(0, Seq(child)) -// builder.add(tree) -// -// val queue = builder.result() -// assert(queue.commands.size == 2) -// assert(queue.commands.head.isInstanceOf[AddManyCommand[Int]]) -// assert(queue.commands(1).isInstanceOf[AddManyCommand[Int]]) -// -// val (Some(egraph2), _) = builder.result().applyImmutable(egraph, Map.empty, ParallelMap.sequential) -// assert(egraph2.classes.size == 2) -// } -// -// /** -// * Adding a tree with a call to the queue adds it to the e-graph. -// */ -// @Test -// def addTreeWithCall(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// val (call, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) -// -// val tree = MixedTree.Atom[Int, EClassSymbol](EClassSymbol.real(call)) -// builder.add(tree) -// -// val queue = builder.result() -// assert(queue.commands.isEmpty) -// -// val (None, _) = builder.result().applyImmutable(egraph2, Map.empty, ParallelMap.sequential) -// } -// -// /** -// * Unions two e-classes in the e-graph. -// */ -// @Test -// def union(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) -// val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) -// -// builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) -// -// val queue = builder.result() -// assert(queue.commands.size == 1) -// assert(queue.commands.head.isInstanceOf[UnionManyCommand[Int]]) -// -// val (Some(egraph4), _) = builder.result().applyImmutable(egraph3, Map.empty, ParallelMap.sequential) -// assert(egraph4.classes.size == 1) -// assert(egraph4.areSame(a, b)) -// } -// -// /** -// * Unions are combined into a single command by CommandQueue.optimized. -// */ -// @Test -// def optimizedUnions(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// val egraph = EGraph.empty[Int] -// val (a, egraph2) = egraph.add(ENode(0, Seq.empty, Seq.empty, Seq.empty)) -// val (b, egraph3) = egraph2.add(ENode(1, Seq.empty, Seq.empty, Seq.empty)) -// val (c, egraph4) = egraph3.add(ENode(2, Seq.empty, Seq.empty, Seq.empty)) -// -// builder.union(EClassSymbol.real(a), EClassSymbol.real(b)) -// builder.union(EClassSymbol.real(b), EClassSymbol.real(c)) -// -// val naiveQueue = builder.result() -// assert(naiveQueue.commands.size == 2) -// assert(naiveQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) -// assert(naiveQueue.commands(1).isInstanceOf[UnionManyCommand[Int]]) -// -// val optimizedQueue = builder.result().optimized -// assert(optimizedQueue.commands.size == 1) -// assert(optimizedQueue.commands.head.isInstanceOf[UnionManyCommand[Int]]) -// -// val (Some(egraph5), _) = optimizedQueue.applyImmutable(egraph4, Map.empty, ParallelMap.sequential) -// assert(egraph5.classes.size == 1) -// assert(egraph5.areSame(a, b)) -// assert(egraph5.areSame(b, c)) -// } -// -// @Test -// def xyUnionYxProducesPermutation(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// -// val x = Slot.fresh() -// val y = Slot.fresh() -// -// val node1 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(x, y), ArraySeq.empty) -// val node2 = ENodeSymbol(0, SlotSeq.empty, SlotSeq(y, x), ArraySeq.empty) -// -// val a = builder.add(node1) -// val b = builder.add(node2) -// -// builder.union(a, b) -// -// val node3 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(a)) -// val node4 = ENodeSymbol(1, SlotSeq(x, y), SlotSeq.empty, ArraySeq(b)) -// -// val c = builder.add(node3) -// val d = builder.add(node4) -// -// for (queue <- Seq(builder.result(), builder.result().optimized)) { -// val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) -// -// assert(egraph.classes.size == 2) -// assert(egraph.areSame(a.reify(reification), b.reify(reification))) -// assert(egraph.areSame(c.reify(reification), d.reify(reification))) -// } -// } -// -// @Test -// def constructTreeWithSlots(): Unit = { -// val builder = new CommandQueueBuilder[Int] -// -// val x = Slot.fresh() -// val y = Slot.fresh() -// -// val tree1 = builder.add(MixedTree.Node(2, Seq.empty, Seq(y), Seq.empty)) -// val tree2 = builder.add( -// MixedTree.Node(0, Seq(x), Seq.empty, Seq(MixedTree.Node(1, Seq.empty, Seq(x), Seq(MixedTree.Atom(tree1)))))) -// -// for (queue <- Seq(builder.result(), builder.result().optimized)) { -// val (Some(egraph), reification) = queue.applyImmutable(EGraph.empty[Int], Map.empty, ParallelMap.sequential) -// -// assert(egraph.classes.size == 3) -// assert(tree1.reify(reification).args.valueSet == Set(y)) -// assert(tree2.reify(reification).args.valueSet == Set(y)) -// -// assert(!egraph.contains( -// MixedTree.Node(0, Seq(x), Seq.empty, Seq( -// MixedTree.Node(1, Seq.empty, Seq(y), Seq(MixedTree.Atom(tree1.reify(reification)))))))) -// } -// } -//} -// diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala new file mode 100644 index 00000000..91cddcf8 --- /dev/null +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -0,0 +1,169 @@ +package foresight.eqsat.commands + +import foresight.eqsat.collections.SlotSeq +import foresight.eqsat.parallel.ParallelMap +import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol} +import foresight.eqsat.immutable.EGraph +import org.junit.Test + +import scala.collection.compat.immutable.ArraySeq + +class CommandScheduleBuilderTest { + /** + * An empty schedule is a no-op. + */ + @Test + def emptyScheduleDoesNothing(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val egraph = EGraph.empty[Int] + + val schedule = builder.result() + assert(schedule.additions.isEmpty) + assert(schedule.unions.isEmpty) + assert(schedule.applyImmutable(egraph, ParallelMap.sequential).isEmpty) + } + + /** + * Adding a concrete node in batch 0 inserts one class. + */ + @Test + def addConcreteNodeInBatchZero(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val egraph = EGraph.empty[Int] + + val enode = ENode(0, Seq.empty, Seq.empty, Seq.empty) + val sym: EClassSymbol.Virtual = builder.add(enode, 0) + + val schedule = builder.result() + // Batch structure + assert(schedule.additions.size == 1) + val (symbols0, nodes0) = schedule.additions.head + assert(nodes0.nonEmpty) + assert(schedule.batchZero._2.head == enode) + assert(schedule.batchZero._1.head == sym) + + val e2 = schedule.applyImmutable(egraph, ParallelMap.sequential) + assert(e2.nonEmpty) + val g2 = e2.get + assert(g2.classes.size == 1) + } + + /** + * Add symbolic node to a later batch that depends on a batch-0 symbol. + * Ensures reification is wired and batch order respected. + */ + @Test + def addDependentSymbolicNodeInBatchOne(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val egraph = EGraph.empty[Int] + + val a = builder.add(ENode(0, Seq.empty, Seq.empty, Seq.empty), 0) + val n1 = ENodeSymbol(1, SlotSeq.empty, SlotSeq.empty, ArraySeq(a)) + val b = builder.add(n1, 1) + + val schedule = builder.result() + // Two batches: 0 and 1 + assert(schedule.additions.size == 2) + assert(schedule.batchZero._1.contains(a)) + assert(schedule.otherBatches.head._1.contains(b)) + + val g2 = schedule.applyImmutable(egraph, ParallelMap.sequential).get + assert(g2.classes.size == 2) + } + + /** + * Unions execute after all additions; no-op when already same. + */ + @Test + def unionsExecuteAndNoOpWhenSame(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val e0 = EGraph.empty[Int] + + val a = builder.add(ENode(0, Seq.empty, Seq.empty, Seq.empty), 0) + val b = builder.add(ENode(1, Seq.empty, Seq.empty, Seq.empty), 0) + + builder.union(a, b) + + val schedule = builder.result() + val g1 = schedule.applyImmutable(e0, ParallelMap.sequential).get + assert(g1.classes.size == 1) + + // Apply the same schedule again: already united → no changes + assert(schedule.applyImmutable(g1, ParallelMap.sequential).isEmpty) + } + + /** + * Batch index correctness: all direct parents must have strictly higher batch index than their children. + */ + @Test + def batchIndexParentsGreaterThanChildren(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + + // Build a small diamond: a (batch 0) -> b,c (batch 1) -> d (batch 2) + val a = builder.add(ENode(10, Seq.empty, Seq.empty, Seq.empty), 0) + val b = builder.add(ENodeSymbol(20, SlotSeq.empty, SlotSeq.empty, ArraySeq(a)), 1) + val c = builder.add(ENodeSymbol(21, SlotSeq.empty, SlotSeq.empty, ArraySeq(a)), 1) + val d = builder.add(ENodeSymbol(30, SlotSeq.empty, SlotSeq.empty, ArraySeq(b, c)), 2) + + val schedule = builder.result() + + // Validate batches and membership + assert(schedule.additions.size == 3) + assert(schedule.batchZero._1 == ArraySeq(a)) + assert(schedule.otherBatches(0)._1 == ArraySeq(b, c)) + assert(schedule.otherBatches(1)._1 == ArraySeq(d)) + } + + /** + * Mixed batches: verify that adding the same ENode twice yields no change on the second application, + * and that batch indices are preserved. + */ + @Test + def idempotentApplyAndStableBatches(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val g0 = EGraph.empty[Int] + + val a = builder.add(ENode(1, Seq.empty, Seq.empty, Seq.empty), 0) + val b = builder.add(ENodeSymbol(2, SlotSeq.empty, SlotSeq.empty, ArraySeq(a)), 1) + + val schedule = builder.result() + + val g1 = schedule.applyImmutable(g0, ParallelMap.sequential) + assert(g1.nonEmpty) + + // Re-apply -> no changes + val g2 = schedule.applyImmutable(g1.get, ParallelMap.sequential) + assert(g2.isEmpty) + + // Batches unchanged + assert(schedule.batchZero._1 == ArraySeq(a)) + assert(schedule.otherBatches.head._1 == ArraySeq(b)) + } + + /** + * When using unions with a symbolic and a concrete reference, ensure the union is scheduled and applied. + */ + @Test + def unionWithVirtualAndReal(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val g0 = EGraph.empty[Int] + + val a = builder.add(ENode(7, Seq.empty, Seq.empty, Seq.empty), 0) + val scheduleBeforeUnion = builder.result() + val g1 = scheduleBeforeUnion.applyImmutable(g0, ParallelMap.sequential).get + + // Real call obtained from g1 + val realA: EClassCall = g1.canonicalize(g1.classes.head) + val realSym: EClassSymbol = EClassSymbol.real(realA) + + // Add a new virtual and union it with the real one + val b = CommandScheduleBuilder.newConcurrentBuilder[Int] + val v = b.add(ENode(8, Seq.empty, Seq.empty, Seq.empty), 0) + b.union(v, realSym) + + val sched2 = b.result() + val g2 = sched2.applyImmutable(g1, ParallelMap.sequential).get + + assert(g2.classes.size == 1) + } +} From cd2804722a6d1883792eee04dbf7562a3da86af5 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:19:39 -0500 Subject: [PATCH 14/33] Add tests for addSimplifiedReal method in CommandScheduleBuilder --- .../commands/CommandScheduleBuilderTest.scala | 107 +++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index 91cddcf8..ef42317e 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -2,7 +2,7 @@ package foresight.eqsat.commands import foresight.eqsat.collections.SlotSeq import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol} +import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} import foresight.eqsat.immutable.EGraph import org.junit.Test @@ -166,4 +166,109 @@ class CommandScheduleBuilderTest { assert(g2.classes.size == 1) } + + /** + * addSimplifiedReal: Atom path should return the same real symbol and schedule nothing. + */ + @Test + def addSimplifiedRealAtomNoop(): Unit = { + val baseBuilder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val g0 = EGraph.empty[Int] + + // Create one real class in the graph so we can obtain a real EClassCall. + val seed = CommandScheduleBuilder.newConcurrentBuilder[Int] + val v = seed.add(ENode(42, Seq.empty, Seq.empty, Seq.empty), 0) + val sched = seed.result() + val g1 = sched.applyImmutable(g0, ParallelMap.sequential).get + + // Get a real call referring to the inserted class. + val realCall: EClassCall = g1.canonicalize(g1.classes.head) + + // Use addSimplifiedReal on an Atom (real call). + val sym = baseBuilder.addSimplifiedReal( + MixedTree.Atom[Int, EClassCall](realCall), + g1 + ) + + // It should return the same real symbol and schedule nothing. + assert(sym == EClassSymbol.real(realCall)) + val out = baseBuilder.result() + assert(out.additions.isEmpty) + assert(out.unions.isEmpty) + } + + /** + * addSimplifiedReal: Single node with no children should schedule exactly one addition. + */ + @Test + def addSimplifiedRealSingleNodeAddsOne(): Unit = { + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + val g0 = EGraph.empty[Int] + + // Build a trivial 1-node tree: Node(7, [], [], []) + val tree = MixedTree.Node[Int, EClassCall](7, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty[MixedTree[Int, EClassCall]]) + val sym = builder.addSimplifiedReal(tree, g0) + + // We only care that exactly one node gets scheduled and applying the schedule grows the graph by 1. + val sched = builder.result() + assert(sched.additions.size == 1) + + val g1opt = sched.applyImmutable(g0, ParallelMap.sequential) + assert(g1opt.nonEmpty) + val g1 = g1opt.get + assert(g1.classes.nonEmpty) + + // The returned symbol should be virtual and present in the first (and only) batch. + val (syms0, nodes0) = sched.additions.head + assert(syms0.contains(sym)) + assert(nodes0.nonEmpty) + } + + /** + * addSimplifiedReal: Nested tree should create children first (earlier batch), then parent (later batch). + */ + @Test + def addSimplifiedRealNestedBatchesIncrease(): Unit = { + val g0 = EGraph.empty[Int] + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + + // Seed a real leaf in the graph to use as an Atom in the child. + val seed = CommandScheduleBuilder.newConcurrentBuilder[Int] + val leafV = seed.add(ENode(1, Seq.empty, Seq.empty, Seq.empty), 0) + val seedSched = seed.result() + val g1 = seedSched.applyImmutable(g0, ParallelMap.sequential).get + val realLeaf: EClassCall = g1.canonicalize(g1.classes.head) + + // child = Node(10, [], [], [ Atom(realLeaf) ]) + val child = MixedTree.Node[Int, EClassCall]( + 10, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(MixedTree.Atom[Int, EClassCall](realLeaf)) + ) + // parent = Node(20, [], [], [ child ]) + val parent = MixedTree.Node[Int, EClassCall]( + 20, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(child) + ) + + // Instantiate child first to get its symbol and ensure batch ordering is visible. + val childSym = builder.addSimplifiedReal(child, g1) + val parentSym = builder.addSimplifiedReal(parent, g1) + + val sched = builder.result() + // Expect two distinct non-empty batches for additions beyond batch 0. + assert(sched.additions.size == 2) + assert(sched.otherBatches.length == 1) + + // First batch should contain the child, second batch the parent. + assert(sched.batchZero._1.contains(childSym)) + assert(sched.otherBatches(0)._1.contains(parentSym)) + + // Applying the schedule should add exactly two new classes. + val g2 = sched.applyImmutable(g1, ParallelMap.sequential).get + assert(g2.classes.size == g1.classes.size + 2) + } } From 1fec937bb20483c7bcbb70a19ff6cd223e8e0236 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:21:00 -0500 Subject: [PATCH 15/33] Add test for addSimplifiedReal method to ensure batch layering stability in complex trees --- .../commands/CommandScheduleBuilderTest.scala | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index ef42317e..956bde33 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -271,4 +271,90 @@ class CommandScheduleBuilderTest { val g2 = sched.applyImmutable(g1, ParallelMap.sequential).get assert(g2.classes.size == g1.classes.size + 2) } + + + /** + * addSimplifiedReal: Complex tree should preserve strict batch layering across depths. + * Level 1 (children of real atoms) -> Level 2 (parents of level 1) -> Level 3 (root). + */ + @Test + def addSimplifiedRealComplexTreeBatchesStable(): Unit = { + val g0 = EGraph.empty[Int] + val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] + + // Seed two real leaves in the graph to use as Atoms. + val seed = CommandScheduleBuilder.newConcurrentBuilder[Int] + val leafA = seed.add(ENode(101, Seq.empty, Seq.empty, Seq.empty), 0) + val leafB = seed.add(ENode(102, Seq.empty, Seq.empty, Seq.empty), 0) + val seedSched = seed.result() + val g1 = seedSched.applyImmutable(g0, ParallelMap.sequential).get + val realA: EClassCall = g1.canonicalize(g1.classes.head) + val realB: EClassCall = { + // the second class; order is not guaranteed, so find the one that's not realA + val calls = g1.classes.map(g1.canonicalize) + calls.find(_ != realA).get + } + + // Level 1 nodes (directly depend on real atoms) + val child1 = MixedTree.Node[Int, EClassCall]( + 201, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(MixedTree.Atom[Int, EClassCall](realA)) + ) + val child2 = MixedTree.Node[Int, EClassCall]( + 202, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(MixedTree.Atom[Int, EClassCall](realA), MixedTree.Atom[Int, EClassCall](realB)) + ) + val childSym1 = builder.addSimplifiedReal(child1, g1) + val childSym2 = builder.addSimplifiedReal(child2, g1) + + // Level 2 nodes (depend on Level 1) + val parent1 = MixedTree.Node[Int, EClassCall]( + 301, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(child1, MixedTree.Atom[Int, EClassCall](realB)) + ) + val parent2 = MixedTree.Node[Int, EClassCall]( + 302, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(child2) + ) + val parentSym1 = builder.addSimplifiedReal(parent1, g1) + val parentSym2 = builder.addSimplifiedReal(parent2, g1) + + // Level 3 node (root depends on both Level 2 nodes) + val root = MixedTree.Node[Int, EClassCall]( + 401, + SlotSeq.empty, + SlotSeq.empty, + ArraySeq(parent1, parent2) + ) + val rootSym = builder.addSimplifiedReal(root, g1) + + val sched = builder.result() + + // Expect three distinct non-empty batches for additions: L1, L2, L3 + assert(sched.additions.size == 3) + assert(sched.otherBatches.length == 2) + + // Level 1 batch (batchZero) should contain both children + assert(sched.batchZero._1.contains(childSym1)) + assert(sched.batchZero._1.contains(childSym2)) + + // Level 2 batch should contain both parents + assert(sched.otherBatches(0)._1.contains(parentSym1)) + assert(sched.otherBatches(0)._1.contains(parentSym2)) + + // Level 3 batch should contain the root + assert(sched.otherBatches(1)._1.contains(rootSym)) + + // Applying the schedule should add exactly 5 new classes (2 L1 + 2 L2 + 1 L3) + val g2 = sched.applyImmutable(g1, ParallelMap.sequential).get + assert(g2.classes.size == g1.classes.size + 5) + } } From ff50e1a49029a21c20bac84b29c585813db91a4b Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:22:28 -0500 Subject: [PATCH 16/33] Refactor addSimplifiedReal method documentation for clarity and consistency --- .../commands/CommandScheduleBuilderTest.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index 956bde33..63c4d141 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -93,7 +93,7 @@ class CommandScheduleBuilderTest { } /** - * Batch index correctness: all direct parents must have strictly higher batch index than their children. + * Batch index correctness: all direct parents have strictly higher batch index than their children. */ @Test def batchIndexParentsGreaterThanChildren(): Unit = { @@ -168,7 +168,7 @@ class CommandScheduleBuilderTest { } /** - * addSimplifiedReal: Atom path should return the same real symbol and schedule nothing. + * addSimplifiedReal: Atom path returns the same real symbol and schedules nothing. */ @Test def addSimplifiedRealAtomNoop(): Unit = { @@ -190,7 +190,7 @@ class CommandScheduleBuilderTest { g1 ) - // It should return the same real symbol and schedule nothing. + // It returns the same real symbol and schedules nothing. assert(sym == EClassSymbol.real(realCall)) val out = baseBuilder.result() assert(out.additions.isEmpty) @@ -198,7 +198,7 @@ class CommandScheduleBuilderTest { } /** - * addSimplifiedReal: Single node with no children should schedule exactly one addition. + * addSimplifiedReal: Single node with no children schedules exactly one addition. */ @Test def addSimplifiedRealSingleNodeAddsOne(): Unit = { @@ -218,14 +218,14 @@ class CommandScheduleBuilderTest { val g1 = g1opt.get assert(g1.classes.nonEmpty) - // The returned symbol should be virtual and present in the first (and only) batch. + // The returned symbol is virtual and present in the first (and only) batch. val (syms0, nodes0) = sched.additions.head assert(syms0.contains(sym)) assert(nodes0.nonEmpty) } /** - * addSimplifiedReal: Nested tree should create children first (earlier batch), then parent (later batch). + * addSimplifiedReal: Nested tree creates children first (earlier batch), then parent (later batch). */ @Test def addSimplifiedRealNestedBatchesIncrease(): Unit = { @@ -263,18 +263,18 @@ class CommandScheduleBuilderTest { assert(sched.additions.size == 2) assert(sched.otherBatches.length == 1) - // First batch should contain the child, second batch the parent. + // First batch contains the child, second batch the parent. assert(sched.batchZero._1.contains(childSym)) assert(sched.otherBatches(0)._1.contains(parentSym)) - // Applying the schedule should add exactly two new classes. + // Applying the schedule adds exactly two new classes. val g2 = sched.applyImmutable(g1, ParallelMap.sequential).get assert(g2.classes.size == g1.classes.size + 2) } /** - * addSimplifiedReal: Complex tree should preserve strict batch layering across depths. + * addSimplifiedReal: Complex tree preserves strict batch layering across depths. * Level 1 (children of real atoms) -> Level 2 (parents of level 1) -> Level 3 (root). */ @Test @@ -342,18 +342,18 @@ class CommandScheduleBuilderTest { assert(sched.additions.size == 3) assert(sched.otherBatches.length == 2) - // Level 1 batch (batchZero) should contain both children + // Level 1 batch (batchZero) contains both children assert(sched.batchZero._1.contains(childSym1)) assert(sched.batchZero._1.contains(childSym2)) - // Level 2 batch should contain both parents + // Level 2 batch contains both parents assert(sched.otherBatches(0)._1.contains(parentSym1)) assert(sched.otherBatches(0)._1.contains(parentSym2)) - // Level 3 batch should contain the root + // Level 3 batch contains the root assert(sched.otherBatches(1)._1.contains(rootSym)) - // Applying the schedule should add exactly 5 new classes (2 L1 + 2 L2 + 1 L3) + // Applying the schedule adds exactly 5 new classes (2 L1 + 2 L2 + 1 L3) val g2 = sched.applyImmutable(g1, ParallelMap.sequential).get assert(g2.classes.size == g1.classes.size + 5) } From 91aaa340d18192972d78c100fb2c06c190097f9f Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:22:58 -0500 Subject: [PATCH 17/33] Rename test method for addSimplifiedReal to improve clarity --- .../foresight/eqsat/commands/CommandScheduleBuilderTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index 63c4d141..5e18370e 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -278,7 +278,7 @@ class CommandScheduleBuilderTest { * Level 1 (children of real atoms) -> Level 2 (parents of level 1) -> Level 3 (root). */ @Test - def addSimplifiedRealComplexTreeBatchesStable(): Unit = { + def addSimplifiedRealComplexTree(): Unit = { val g0 = EGraph.empty[Int] val builder = CommandScheduleBuilder.newConcurrentBuilder[Int] From fbe7d1dd39c1ccfd9ca05fe3cc59ed69acb2ac24 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:25:18 -0500 Subject: [PATCH 18/33] Define custom IntRef for cross-version compatibility --- .../foresight/eqsat/commands/CommandScheduleBuilder.scala | 1 - foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala | 3 +++ .../foresight/eqsat/rewriting/patterns/PatternApplier.scala | 3 +-- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index d62ab9c0..80aa2079 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -7,7 +7,6 @@ import foresight.util.Debug import foresight.util.collections.UnsafeSeqFromArray import scala.collection.compat.immutable.ArraySeq -import scala.runtime.IntRef /** * Constructs commands for later execution. Commands are scheduled in batches. diff --git a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala new file mode 100644 index 00000000..8aecbe98 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala @@ -0,0 +1,3 @@ +package foresight.eqsat.commands + +private[eqsat] class IntRef(var elem: Int) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index b8b4ec3c..84845a9a 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -1,13 +1,12 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.collections.SlotSeq -import foresight.eqsat.commands.CommandScheduleBuilder +import foresight.eqsat.commands.{CommandScheduleBuilder, IntRef} import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} import foresight.eqsat.{EClassSymbol, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq -import scala.runtime.IntRef /** * An applier that applies a pattern match to an e-graph. From 2daf835ae3d8d3461745a236bb98ef4e816d41d9 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:26:05 -0500 Subject: [PATCH 19/33] Replace IntRef instantiation with new keyword for consistency --- .../foresight/eqsat/commands/CommandScheduleBuilder.scala | 4 ++-- .../foresight/eqsat/rewriting/patterns/PatternApplier.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index 80aa2079..a8c9fbf2 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -83,7 +83,7 @@ trait CommandScheduleBuilder[NodeT] { private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT]): EClassSymbol = { - val maxBatch = IntRef(0) + val maxBatch = new IntRef(0) addSimplifiedReal(tree, egraph, maxBatch) } @@ -93,7 +93,7 @@ trait CommandScheduleBuilder[NodeT] { tree match { case MixedTree.Node(t, defs, uses, args) => // Local accumulator for children of this node. - val childMax = IntRef(0) + val childMax = new IntRef(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 84845a9a..91aedae9 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -92,7 +92,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT uses: SlotSeq, args: ArraySeq[MixedTree[NodeT, Pattern.Var]], maxBatch: IntRef): EClassSymbol = { - val argMaxBatch = IntRef(0) + val argMaxBatch = new IntRef(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph) From 448ff74a1b86760c206acea26689d03d3d4e2fda Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:26:42 -0500 Subject: [PATCH 20/33] Use new keyword for IdentityHashMap instantiation for consistency --- .../main/scala/foresight/eqsat/commands/CommandSchedule.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala index 0730646e..6f2e25a7 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandSchedule.scala @@ -97,7 +97,7 @@ final case class CommandSchedule[NodeT](batchZero: (ArraySeq[EClassSymbol.Virtua def apply(egraph: mutable.EGraph[NodeT], parallelize: ParallelMap): Boolean = { - val reification = util.IdentityHashMap[EClassSymbol.Virtual, foresight.eqsat.EClassCall]() + val reification = new util.IdentityHashMap[EClassSymbol.Virtual, foresight.eqsat.EClassCall]() var anyChanges: Boolean = false anyChanges = anyChanges | applyBatchZero(egraph, parallelize, reification) From 66bf25a75b84a9d396e6da90a22ab937419898f5 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:29:27 -0500 Subject: [PATCH 21/33] Refactor import statements in EGraphWithMetadata for consistency --- .../scala/foresight/eqsat/immutable/EGraphWithMetadata.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala index 0e71e614..64098fe9 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithMetadata.scala @@ -1,7 +1,6 @@ package foresight.eqsat.immutable -import foresight.eqsat.* -import foresight.eqsat.readonly +import foresight.eqsat.{AddNodeResult, EClassCall, ENode, readonly} import foresight.eqsat.metadata.Analysis import foresight.eqsat.parallel.ParallelMap import foresight.util.collections.StrictMapOps.toStrictMapOps From 90eaf583f2501acf93ffc0a2a910fda5d671f23b Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:30:27 -0500 Subject: [PATCH 22/33] Use Function interface for computeIfAbsent in ConcurrentCommandScheduleBuilder --- .../eqsat/commands/ConcurrentCommandScheduleBuilder.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala index e2904f4c..0e8cf0de 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/ConcurrentCommandScheduleBuilder.scala @@ -3,6 +3,7 @@ package foresight.eqsat.commands import foresight.eqsat.{EClassSymbol, ENode, ENodeSymbol} import foresight.util.collections.UnsafeSeqFromArray +import java.util.function.Function import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.collection.compat.immutable.ArraySeq import scala.reflect.ClassTag @@ -19,7 +20,11 @@ private[commands] class ConcurrentCommandScheduleBuilder[NodeT] extends CommandS case _ => throw new IllegalArgumentException("Only ENode instances are allowed in batch 0") } } else { - val queue = otherBatchAdds.computeIfAbsent(batch, _ => new ConcurrentLinkedQueue()) + val queue = otherBatchAdds.computeIfAbsent(batch, new Function[Int, ConcurrentLinkedQueue[(EClassSymbol.Virtual, ENodeSymbol[NodeT])]] { + override def apply(t: Int): ConcurrentLinkedQueue[(EClassSymbol.Virtual, ENodeSymbol[NodeT])] = { + new ConcurrentLinkedQueue() + } + }) queue.add((symbol, node)) } } From bfcadf6eb676531fcf7182415315eaff693d22f7 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:30:35 -0500 Subject: [PATCH 23/33] Refactor import statements in EGraphWithRoot --- .../main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala index 87d9dd07..22b86280 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala @@ -1,6 +1,7 @@ package foresight.eqsat.immutable -import foresight.eqsat.* +import foresight.eqsat.{EClassCall, ENode} +import foresight.eqsat.readonly import foresight.eqsat.parallel.ParallelMap import scala.collection.compat.immutable.ArraySeq From 2501ab0d2d8ae2b36e3ea0938d2c80f7e7870b65 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:32:04 -0500 Subject: [PATCH 24/33] Refactor import statements in EGraphWithRecordedApplications and EGraphWithRoot --- .../eqsat/immutable/EGraphWithRecordedApplications.scala | 3 ++- .../main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala index 4d749896..35026fb9 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRecordedApplications.scala @@ -1,6 +1,7 @@ package foresight.eqsat.immutable -import foresight.eqsat.* +import foresight.eqsat.{AddNodeResult, EClassCall, ENode} +import foresight.eqsat.readonly import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.rewriting.PortableMatch diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala index 22b86280..3df646a5 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphWithRoot.scala @@ -1,7 +1,6 @@ package foresight.eqsat.immutable -import foresight.eqsat.{EClassCall, ENode} -import foresight.eqsat.readonly +import foresight.eqsat.{AddNodeResult, EClassCall, ENode, MixedTree, readonly} import foresight.eqsat.parallel.ParallelMap import scala.collection.compat.immutable.ArraySeq From a2a94d8415e4dd0cda6297228ffb90b9d4de8b22 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:33:11 -0500 Subject: [PATCH 25/33] Refactor import statements in ApplierOps --- .../main/scala/foresight/eqsat/examples/arith/ApplierOps.scala | 2 +- .../main/scala/foresight/eqsat/examples/liar/ApplierOps.scala | 2 +- .../main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala index 375b7de1..afcaa6b5 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala @@ -5,7 +5,7 @@ import foresight.eqsat.extraction.ExtractionAnalysis import foresight.eqsat.readonly.{EGraph, EGraphWithMetadata} import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} -import foresight.eqsat.* +import foresight.eqsat._ object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraph[ArithIR]](private val applier: Applier[ArithIR, PatternMatch[ArithIR], EGraphWithMetadata[ArithIR, EGraphT]]) extends AnyVal { diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala index e0a40ed4..353b4815 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala @@ -6,7 +6,7 @@ import foresight.eqsat.extraction.ExtractionAnalysis import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternApplier, PatternMatch} -import foresight.eqsat.* +import foresight.eqsat._ object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraphLike[ArrayIR, EGraphT] with EGraph[ArrayIR]](private val applier: Applier[ArrayIR, PatternMatch[ArrayIR], EGraphWithMetadata[ArrayIR, EGraphT]]) extends AnyVal { diff --git a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala index b0d8d235..e4af4fa1 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala @@ -5,7 +5,7 @@ import foresight.eqsat.extraction.ExtractionAnalysis import foresight.eqsat.rewriting.Applier import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} -import foresight.eqsat.* +import foresight.eqsat._ object ApplierOps { implicit class ApplierOfPatternMatchOps[EGraphT <: EGraphLike[SdqlIR, EGraphT] with EGraph[SdqlIR]](private val applier: Applier[SdqlIR, PatternMatch[SdqlIR], EGraphWithMetadata[SdqlIR, EGraphT]]) extends AnyVal { From 500718e6fef9703ea629c004b3b478898902a827 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:33:41 -0500 Subject: [PATCH 26/33] Make IntRef class final --- foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala index 8aecbe98..f0992cd6 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala @@ -1,3 +1,3 @@ package foresight.eqsat.commands -private[eqsat] class IntRef(var elem: Int) +private[eqsat] final class IntRef(var elem: Int) From a015f0eae0ce4bd56defe023b4ff22185b358705 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 21:57:22 -0500 Subject: [PATCH 27/33] Refactor processBlock to process stripes instead of contiguous blocks --- .../eqsat/parallel/ParallelMap.scala | 42 +++++++++---------- .../eqsat/parallel/ParallelMapTest.scala | 23 ---------- 2 files changed, 21 insertions(+), 44 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/parallel/ParallelMap.scala b/foresight/src/main/scala/foresight/eqsat/parallel/ParallelMap.scala index 6e700382..e5cbe194 100644 --- a/foresight/src/main/scala/foresight/eqsat/parallel/ParallelMap.scala +++ b/foresight/src/main/scala/foresight/eqsat/parallel/ParallelMap.scala @@ -150,32 +150,31 @@ trait ParallelMap { * Processes a single block of elements from an array sequence. * @param inputs The input array sequence. * @param blockIdx The index of the block to process. - * @param blockSize The size of each block. + * @param stride The stride between elements in the block. * @param f The function to apply to each element. * @tparam A The type of elements in the input array sequence. */ - protected def processBlock[A](inputs: ArraySeq[A], blockIdx: Int, blockSize: Int, f: A => Unit): Unit = { - val start = blockIdx * blockSize - val end = math.min(start + blockSize, inputs.length) - var i = start - while (i < end) { + protected def processBlock[A](inputs: ArraySeq[A], blockIdx: Int, stride: Int, f: A => Unit): Unit = { + val len = inputs.length + var i = blockIdx + while (i < len) { f(inputs(i)) - i += 1 + i += stride } } /** - * Applies a function to each element of an array sequence in blocks. Blocks control the - * granularity of parallelism: each block is processed sequentially, while different blocks - * may be processed in parallel. + * Applies a function to each element of an array sequence using **stripes**. + * A stripe processes every `numStripes`-th element starting at a given offset: + * - Stripe `s` processes indices `s, s + numStripes, s + 2*numStripes, ...`. * - * - Each block is processed in input order (sequentially within the block). - * - Blocks may execute concurrently, and their relative order is not guaranteed. + * - Each stripe is processed in input order relative to its own indices (sequential within the stripe). + * - Stripes may execute concurrently; their relative completion order is not guaranteed. * - If wrapped by a [[cancelable]] strategy, cancellation is checked before and during - * processing (between blocks and between elements within a block). + * processing (between stripes and between elements within a stripe). * * @param inputs The input array sequence. - * @param blockSize The size of each block to process. Must be positive. + * @param blockSize The number of stripes (acts as the stride). Must be positive. * @param f The function to apply to each element. * @tparam A The type of elements in the input array sequence. * @throws IllegalArgumentException if `blockSize <= 0` @@ -185,18 +184,19 @@ trait ParallelMap { throw new IllegalArgumentException(s"blockSize must be positive, got $blockSize") } - val numBlocks = ((inputs.length.toLong + blockSize - 1) / blockSize).toInt - if (numBlocks == 0) { - // No blocks: nothing to do + if (inputs.isEmpty) { return } + // Cap stripes to the input length to avoid launching empty work + val numBlocks = ((inputs.length.toLong + blockSize - 1) / blockSize).toInt + if (numBlocks == 1) { - // Single block: process sequentially to avoid overhead - processBlock(inputs, 0, blockSize, f) + // Single stripe: process sequentially to avoid overhead + processBlock(inputs, 0, 1, f) } else { - // Multiple blocks: process in parallel - apply[Int, Unit](0 until numBlocks, processBlock(inputs, _, blockSize, f)) + // Multiple stripes: process in parallel + apply[Int, Unit](0 until numBlocks, processBlock(inputs, _, numBlocks, f)) } } diff --git a/foresight/src/test/scala/foresight/eqsat/parallel/ParallelMapTest.scala b/foresight/src/test/scala/foresight/eqsat/parallel/ParallelMapTest.scala index 9b8b8bcd..accd12cd 100644 --- a/foresight/src/test/scala/foresight/eqsat/parallel/ParallelMapTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/parallel/ParallelMapTest.scala @@ -201,29 +201,6 @@ class ParallelMapTest { } } - /** - * processBlocks: within each block, element order is preserved even if blocks interleave. - */ - @Test - def processBlocksPerBlockOrderPreserved(): Unit = { - for (impl <- implementations) { - val n = 25 - val blockSize = 6 // creates 5 blocks: [0..5], [6..11], [12..17], [18..23], [24] - val inputs = ArraySeq.unsafeWrapArray((0 until n).toArray) - val seen = new ArrayBuffer[Int]() - impl.processBlocks[Int](inputs, blockSize, i => seen.synchronized { seen += i }) - - // Check: for every block, the subsequence of seen that belongs to the block is increasing. - val numBlocks = (n + blockSize - 1) / blockSize - for (b <- 0 until numBlocks) { - val start = b * blockSize - val end = math.min(start + blockSize, n) - val subseq = seen.filter(i => i >= start && i < end) - assert(subseq == subseq.sorted, s"Elements within block $b not in order: $subseq") - } - } - } - /** * processBlocks: blockSize == 1 is valid and processes all elements. */ From ba8a360a5f11851d996660b5de6be778ed46fbb0 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:11:40 -0500 Subject: [PATCH 28/33] Enable search loop interchange --- .../scala/foresight/eqsat/saturation/SearchAndApply.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala index 80b3ee4a..4cdf3e10 100644 --- a/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala +++ b/foresight/src/main/scala/foresight/eqsat/saturation/SearchAndApply.scala @@ -170,7 +170,7 @@ object SearchAndApply { MatchT ]: SearchAndApply[NodeT, Rewrite[NodeT, MatchT, EGraphT], EGraphT, MatchT] = { new NoMatchCaching[NodeT, EGraphT, MatchT] { - override def searchLoopInterchange: Boolean = false + override def searchLoopInterchange: Boolean = true override def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], @@ -195,7 +195,7 @@ object SearchAndApply { MatchT ]: SearchAndApply[NodeT, Rewrite[NodeT, MatchT, EGraphT], EGraphT, MatchT] = { new NoMatchCaching[NodeT, EGraphT, MatchT] { - override def searchLoopInterchange: Boolean = false + override def searchLoopInterchange: Boolean = true override def update(command: CommandSchedule[NodeT], matches: Map[String, Seq[MatchT]], From 2d6239dff1c7f32d018562a01e0707ab7514c774 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:37:17 -0500 Subject: [PATCH 29/33] Enhance users method to handle nodes without slots --- .../eqsat/hashCons/ReadOnlyHashConsEGraph.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/hashCons/ReadOnlyHashConsEGraph.scala b/foresight/src/main/scala/foresight/eqsat/hashCons/ReadOnlyHashConsEGraph.scala index e8a98a04..83424399 100644 --- a/foresight/src/main/scala/foresight/eqsat/hashCons/ReadOnlyHashConsEGraph.scala +++ b/foresight/src/main/scala/foresight/eqsat/hashCons/ReadOnlyHashConsEGraph.scala @@ -196,9 +196,13 @@ private[hashCons] trait ReadOnlyHashConsEGraph[NodeT] extends EGraph[NodeT] { final override def users(ref: EClassRef): Iterable[ENode[NodeT]] = { val canonicalApp = canonicalize(ref) dataForClass(canonicalApp.ref).users.map(node => { - val c = nodeToRef(node) - val mapping = dataForClass(c).nodes(node) - ShapeCall(node, mapping).asNode + if (node.hasSlots) { + val c = nodeToRef(node) + val mapping = dataForClass(c).nodes(node) + ShapeCall(node, mapping).asNode + } else { + node + } }) } From a82ee38260d7353f815a02797ac7d8492456e04d Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:37:45 -0500 Subject: [PATCH 30/33] Refactor compute method to implement worklist algorithm for node processing --- .../eqsat/mutable/AnalysisMetadata.scala | 94 +++++++++++++++++-- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/mutable/AnalysisMetadata.scala b/foresight/src/main/scala/foresight/eqsat/mutable/AnalysisMetadata.scala index b389c798..64930071 100644 --- a/foresight/src/main/scala/foresight/eqsat/mutable/AnalysisMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/mutable/AnalysisMetadata.scala @@ -113,19 +113,97 @@ object AnalysisMetadata { def compute[NodeT, A](analysis: Analysis[NodeT, A], egraph: EGraph[NodeT]): AnalysisMetadata[NodeT, A] = { val result = new AnalysisMetadata[NodeT, A](analysis) - val updater = new AnalysisUpdater(analysis, egraph, result.results) - // Seed: nodes with no arguments. - for (c <- egraph.classes) { - for (node <- egraph.nodes(egraph.canonicalize(c))) { - if (node.args.isEmpty) { - updater.update(c, analysis.make(node, Seq.empty)) + import scala.collection.mutable + + // Worklist of nodes that are READY (all arg classes have results). + val nodeQueue = new mutable.Queue[ENode[NodeT]]() + // Tracks nodes currently pending in the queue; cleared on dequeue to allow re-enqueue after deps improve. + val enqueued = new mutable.HashSet[ENode[NodeT]]() + + // Fast contains on class results. + @inline def hasClass(ref: EClassRef): Boolean = + result.results.contains(ref) + + // Read a class result (must exist). + @inline def getClass(ref: EClassRef): A = + result.results(ref) + + // Compute A for an EClassCall using already-available class result. + @inline def evalCall(call: EClassCall): A = + analysis.rename(getClass(call.ref), call.args) + + // Merge a nodeResult into its class; return true if the class improved. + def mergeIntoClass(ref: EClassRef, nodeResult: A): Boolean = { + result.results.get(ref) match { + case Some(old) => + val joined = analysis.join(old, nodeResult) + if (joined.asInstanceOf[AnyRef] ne old.asInstanceOf[AnyRef]) { // fast-path ref compare + if (joined != old) { // fallback equals in case A is value-based + result.results(ref) = joined + true + } else false + } else false + case None => + result.results(ref) = nodeResult + true + } + } + + // When a class improves, try to enqueue its user nodes that are now fully ready. + def onClassImproved(ref: EClassRef): Unit = { + val users = egraph.users(ref) + for (n <- users) { + if (!enqueued.contains(n)) { + // A node is ready if all its argument classes already have results. + val ready = { + var ok = true + var ai = 0 + val as = n.args + while (ok && ai < as.length) { + ok = hasClass(as(ai).ref) + ai += 1 + } + ok + } + if (ready) { + enqueued += n + nodeQueue.enqueue(n) + } } } } - // Propagate to a fixed point; eventually touches all e-nodes. - updater.processPending(initialized = false) + // Seed step: evaluate all nullary nodes and propagate their classes. + // Also opportunistically enqueue any users that become ready. + { + val classes = egraph.classes + for (c <- classes) { + val canon = egraph.canonicalize(c) + val nodes = egraph.nodes(canon) + for (n <- nodes) { + if (n.args.isEmpty) { + val nodeResult = analysis.make(n, Seq.empty) + if (mergeIntoClass(c, nodeResult)) { + onClassImproved(c) + } + } + } + } + } + + // Main loop: process ready nodes once. + while (nodeQueue.nonEmpty) { + val n = nodeQueue.dequeue() + // Mark as no longer pending so that future upstream improvements can re-enqueue it. + enqueued -= n + // All arg classes are available by construction. + val nodeResult = analysis.make(n, n.args.map(evalCall)) + val cls = egraph.find(n).get.ref + if (mergeIntoClass(cls, nodeResult)) { + onClassImproved(cls) + } + } result } From b2516feeb274c1b396b6f7d0e751b290d143dc3d Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:45:46 -0500 Subject: [PATCH 31/33] Implement readiness-driven worklist for node processing in AnalysisUpdater --- .../eqsat/readonly/AnalysisUpdater.scala | 83 ++++++++++++------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala index 6e9f4ddc..55e98b2e 100644 --- a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala +++ b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala @@ -16,7 +16,10 @@ import scala.collection.mutable private[eqsat] abstract class AnalysisUpdater[NodeT, A](analysis: Analysis[NodeT, A], egraph: EGraph[NodeT]) { - private val worklist = mutable.Set.empty[ENode[NodeT]] + // Readiness-driven worklist: only nodes whose argument classes already have results are queued. + private val nodeQueue = new mutable.Queue[ENode[NodeT]]() + // Tracks nodes currently pending in the queue to avoid duplicate enqueues while in flight. + private val enqueued = new mutable.HashSet[ENode[NodeT]]() /** * Computes the analysis result for an e-class reference. @@ -47,6 +50,30 @@ private[eqsat] abstract class AnalysisUpdater[NodeT, A](analysis: Analysis[NodeT */ def add(ref: EClassRef, result: A): Unit + /** Enqueue a node if all of its argument classes currently have results. */ + private def enqueueIfReady(node: ENode[NodeT]): Unit = { + if (!enqueued.contains(node)) { + // Node is ready iff all arg classes are present. + var ready = true + val as = node.args + var i = 0 + while (ready && i < as.length) { + ready = contains(as(i).ref) + i += 1 + } + if (ready) { + enqueued += node + nodeQueue.enqueue(node) + } + } + } + + /** After a class improves, some of its user nodes may now be ready; enqueue those. */ + private def onClassImproved(ref: EClassRef): Unit = { + val it = egraph.users(ref).iterator + while (it.hasNext) enqueueIfReady(it.next()) + } + /** * Updates the analysis result for an e-class. * @@ -58,7 +85,7 @@ private[eqsat] abstract class AnalysisUpdater[NodeT, A](analysis: Analysis[NodeT case Some(oldResult) if oldResult == result => () case _ => add(ref, result) - worklist ++= egraph.users(ref) + onClassImproved(ref) } } @@ -70,39 +97,31 @@ private[eqsat] abstract class AnalysisUpdater[NodeT, A](analysis: Analysis[NodeT final def apply(call: EClassCall): A = analysis.rename(apply(call.ref), call.args) /** - * Processes the worklist by applying the analysis to potentially updated e-nodes. - * @param initialized Whether the analysis has been initialized for the e-graph. Initialization means that each - * e-class has an analysis result. + * Processes the readiness-driven worklist. Nodes are enqueued exactly when their last dependency becomes available. */ final def processPending(initialized: Boolean = true): Unit = { - while (worklist.nonEmpty) { - // Group the worklist by the e-class of the e-node. - val worklistPerClass = worklist.groupBy(n => egraph.find(n).get.ref) - - // Clear the worklist so it can accept further updates. - worklist.clear() - - // Apply the analysis to the updates e-nodes in each e-class. - for ((ref, nodes) <- worklistPerClass) { - val init = get(ref) - if (initialized) { - assert(init.isDefined, s"Analysis not initialized for $ref") - } - - val result = nodes.foldLeft(init)((acc, node) => { - if (node.args.forall(arg => contains(arg.ref))) { - val args = node.args.map(apply) - val nodeResult = analysis.make(node, args) - acc match { - case None => Some(nodeResult) - case Some(result) => Some(analysis.join(result, nodeResult)) - } - } else { - acc + // Process nodes whose arguments are already available; new class improvements will enqueue more. + while (nodeQueue.nonEmpty) { + val node = nodeQueue.dequeue() + // Allow future re-enqueues if upstream classes improve again. + enqueued -= node + + // By construction, all args have results. + val args = node.args.map(apply) + + val nodeResult = analysis.make(node, args) + val ref = egraph.find(node).get.ref + + get(ref) match { + case Some(old) => + val joined = analysis.join(old, nodeResult) + if ((joined.asInstanceOf[AnyRef] ne old.asInstanceOf[AnyRef]) && joined != old) { + add(ref, joined) + onClassImproved(ref) } - }) - - result.foreach(update(ref, _)) + case None => + add(ref, nodeResult) + onClassImproved(ref) } } } From 800a4cd1746b2ecfb8319837b1fd6c1c26ad0f80 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:47:07 -0500 Subject: [PATCH 32/33] Remove unnecessary parameter from processPending method in AnalysisUpdater --- .../main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala index 55e98b2e..0411bbfd 100644 --- a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala +++ b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisUpdater.scala @@ -99,7 +99,7 @@ private[eqsat] abstract class AnalysisUpdater[NodeT, A](analysis: Analysis[NodeT /** * Processes the readiness-driven worklist. Nodes are enqueued exactly when their last dependency becomes available. */ - final def processPending(initialized: Boolean = true): Unit = { + final def processPending(): Unit = { // Process nodes whose arguments are already available; new class improvements will enqueue more. while (nodeQueue.nonEmpty) { val node = nodeQueue.dequeue() From c42c29c3a544aff27c7d1b0dfab50b0942a72177 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sat, 8 Nov 2025 22:50:31 -0500 Subject: [PATCH 33/33] Remove unnecessary argument from processPending call in AnalysisMetadata --- .../main/scala/foresight/eqsat/immutable/AnalysisMetadata.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/AnalysisMetadata.scala b/foresight/src/main/scala/foresight/eqsat/immutable/AnalysisMetadata.scala index 4a43d3ba..e517c7eb 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/AnalysisMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/AnalysisMetadata.scala @@ -149,7 +149,7 @@ object AnalysisMetadata { } // Propagate to a fixed point; eventually touches all e-nodes. - updater.processPending(initialized = false) + updater.processPending() AnalysisMetadata(analysis, updater.results) }