diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala index bc1d6b80..cffb09f1 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala @@ -32,7 +32,7 @@ object ApplierOps { def subst(tree: ArithExpr): ArithExpr = { tree match { - case Var(slot) if slot == m(from) => L.fromTree[EClassCall](m(to.variable)) + case Var(slot) if slot == m(from) => L.fromTree(m(to.variable)) case Var(slot) => Var(slot) case Lam(param, body) => Lam(param, subst(body)) case App(fun, arg) => App(subst(fun), subst(arg)) @@ -46,7 +46,7 @@ object ApplierOps { } val substituted = subst(extractedExpr) - val newMatch = m.copy(varMapping = m.varMapping + (destination.variable -> L.toTree[EClassCall](substituted))) + val newMatch = m.bind(destination.variable, L.toCallTree(substituted)) applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala index 4e4ac43a..871c2859 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala @@ -64,7 +64,7 @@ final case class Rules()(using L: Language[ArithExpr]) { result.toSeq.map { value => // If a constant value is found, create a new Number node and bind it to the variable, overwriting the // original binding. - subst.bind(x.variable, L.toTree(Number(value))) + subst.bind(x.variable, L.toCallTree(Number(value))) } }), L.toApplier(x) diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala b/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala index 50d0da50..66ecdae8 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala @@ -3,7 +3,7 @@ package foresight.eqsat.examples.mm import foresight.eqsat.lang.{Language, LanguageOp} import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.Rule -import foresight.eqsat.rewriting.patterns.PatternMatch +import foresight.eqsat.rewriting.patterns.AbstractPatternMatch import scala.language.implicitConversions @@ -13,17 +13,17 @@ import scala.language.implicitConversions final case class Rules()(using L: Language[LinalgExpr]) { type Op = LanguageOp[LinalgExpr] type LinalgEGraph = EGraph[LinalgIR] - type LinalgRule = Rule[LinalgIR, PatternMatch[LinalgIR], LinalgEGraph] + type LinalgRule = Rule[LinalgIR, AbstractPatternMatch[LinalgIR], LinalgEGraph] - import L.rule + import L.borrowingRule val matMulAssociativity1: LinalgRule = - rule("mul-associativity1") { (x, y, z) => + borrowingRule("mul-associativity1") { (x, y, z) => ((x * y) * z) -> (x * (y * z)) } val matMulAssociativity2: LinalgRule = - rule("mul-associativity2") { (x, y, z) => + borrowingRule("mul-associativity2") { (x, y, z) => (x * (y * z)) -> ((x * y) * z) } diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala index afcaa6b5..61f9e7f9 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala @@ -27,16 +27,16 @@ object ApplierOps { override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT], builder: CommandScheduleBuilder[ArithIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArithIR].extractor[EGraphT](m(source), egraph) - def subst(tree: Tree[ArithIR]): MixedTree[ArithIR, EClassCall] = { + def subst(tree: Tree[ArithIR]): CallTree[ArithIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq()) if use == m(from) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala index 795d60ea..1040603d 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala @@ -4,7 +4,7 @@ import foresight.eqsat.examples.arith.ApplierOps._ import foresight.eqsat.readonly.{EGraph, EGraphWithMetadata} import foresight.eqsat.rewriting.Rule import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} -import foresight.eqsat.{MixedTree, Slot} +import foresight.eqsat.{CallTree, MixedTree, Slot} /** * This object contains a collection of rules for rewriting arithmetic expressions. @@ -63,7 +63,7 @@ object Rules { result.toSeq.map { value => // If a constant value is found, create a new Number node and bind it to the variable, overwriting the // original binding. - subst.bind(x, Number(value)) + subst.bind(x, CallTree.from(Number(value))) } }), MixedTree.Atom[ArithIR, Pattern.Var](x).toApplier diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala index 353b4815..aebc522e 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala @@ -28,23 +28,23 @@ object ApplierOps { override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT], builder: CommandScheduleBuilder[ArrayIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArrayIR].extractor[EGraphT](m(source), egraph) - def typeOf(tree: MixedTree[ArrayIR, EClassCall]): MixedTree[Type, EClassCall] = { - TypeInferenceAnalysis.get(egraph)(tree, egraph) + def typeOf(tree: CallTree[ArrayIR]): CallTree[Type] = { + CallTree.from(TypeInferenceAnalysis.get(egraph)(tree, egraph)) } - def subst(tree: Tree[ArrayIR]): MixedTree[ArrayIR, EClassCall] = { + def subst(tree: Tree[ArrayIR]): CallTree[ArrayIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq(fromType)) - if use == m(from) && typeOf(m(to)) == MixedTree.fromTree(fromType) => + if use == m(from) && typeOf(m(to)) == CallTree.from(fromType) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } @@ -66,7 +66,7 @@ object ApplierOps { val tree = applier.instantiate(m) val realTree = tree.mapAtoms(_.asInstanceOf[EClassCall]) inferType(realTree, egraph) - val c = builder.addSimplifiedReal(realTree, egraph) + val c = builder.addSimplifiedReal(CallTree.from(realTree), egraph) builder.unionSimplified(EClassSymbol.real(m.root), c, egraph) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala index 8b2ddf20..fca010d6 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala @@ -4,7 +4,7 @@ import foresight.eqsat.examples.liar.TypeRequirements.RequirementsSearcherContin import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.rewriting.patterns.{CompiledPattern, Pattern, PatternMatch} import foresight.eqsat.rewriting.{Applier, ReversibleSearcher, Searcher} -import foresight.eqsat.MixedTree +import foresight.eqsat.{CallTree, EClassCall, MixedTree} import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} object SearcherOps { @@ -23,8 +23,8 @@ object SearcherOps { searcher.map((m, egraph) => { val newVarMapping = m.varMapping ++ types.map { case (value, t) => - val (call, newEGraph) = egraph.add(m(value)) - t -> TypeInferenceAnalysis.get(newEGraph)(call, newEGraph) + val (call, newEGraph) = egraph.add(m(value).toMixedTree) + t -> CallTree.from(TypeInferenceAnalysis.get(newEGraph)(call, newEGraph)) } PatternMatch(m.root, newVarMapping, m.slotMapping) }) @@ -53,8 +53,8 @@ object SearcherOps { searcher.filter((m, egraph) => { values.forall(v => { m(v) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType.isInstanceOf[Value] - case MixedTree.Node(nodeType, _, _, _) => nodeType.isInstanceOf[Value] + case c: EClassCall => egraph.nodes(c).head.nodeType.isInstanceOf[Value] + case CallTree.Node(nodeType, _, _, _) => nodeType.isInstanceOf[Value] } }) }) @@ -69,8 +69,8 @@ object SearcherOps { def requireNonFunctionType(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType != FunctionType - case MixedTree.Node(nodeType, _, _, _) => nodeType != FunctionType + case c: EClassCall => egraph.nodes(c).head.nodeType != FunctionType + case CallTree.Node(nodeType, _, _, _) => nodeType != FunctionType } }) } @@ -83,8 +83,8 @@ object SearcherOps { def requireInt32Type(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType == Int32Type - case MixedTree.Node(nodeType, _, _, _) => nodeType == Int32Type + case c: EClassCall => egraph.nodes(c).head.nodeType == Int32Type + case CallTree.Node(nodeType, _, _, _) => nodeType == Int32Type } }) } @@ -97,8 +97,8 @@ object SearcherOps { def requireDoubleType(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType == DoubleType - case MixedTree.Node(nodeType, _, _, _) => nodeType == DoubleType + case c: EClassCall => egraph.nodes(c).head.nodeType == DoubleType + case CallTree.Node(nodeType, _, _, _) => nodeType == DoubleType } }) } diff --git a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala index e4af4fa1..bb89d51e 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala @@ -27,16 +27,16 @@ object ApplierOps { override def apply(m: PatternMatch[SdqlIR], egraph: EGraphWithMetadata[SdqlIR, EGraphT], builder: CommandScheduleBuilder[SdqlIR]): Unit = { val extracted = ExtractionAnalysis.smallest[SdqlIR].extractor[EGraphT](m(source), egraph) - def subst(tree: Tree[SdqlIR]): MixedTree[SdqlIR, EClassCall] = { + def subst(tree: Tree[SdqlIR]): CallTree[SdqlIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq()) if use == m(from) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala b/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala index d589bdbf..c3371166 100644 --- a/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala +++ b/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala @@ -1,8 +1,7 @@ package foresight.eqsat.examples.mm -import foresight.eqsat.examples.mm.{Fact, LinalgExpr, LinalgIR, Mat, Mul, Rules, *} import foresight.eqsat.lang._ -import foresight.eqsat.saturation.{MaximalRuleApplication, MaximalRuleApplicationWithCaching, Strategy} +import foresight.eqsat.saturation.{MaximalRuleApplication, Strategy} import foresight.eqsat.EClassCall import foresight.eqsat.immutable.EGraph import foresight.eqsat.mutable diff --git a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala index cbc68550..c2ae8a53 100644 --- a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala +++ b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala @@ -1,8 +1,8 @@ package foresight.eqsat.lang import foresight.eqsat.extraction.ExtractionAnalysis -import foresight.eqsat.rewriting.patterns.{Pattern, PatternApplier, PatternMatch} -import foresight.eqsat.rewriting.{ReversibleSearcher, Rule} +import foresight.eqsat.rewriting.patterns.{AbstractPatternMatch, Pattern, PatternApplier, PatternMatch, StateBorrowingMachineEClassSearcher} +import foresight.eqsat.rewriting.{ReversibleSearcher, Rule, Searcher} import foresight.eqsat.* import foresight.eqsat.immutable import foresight.eqsat.mutable @@ -90,6 +90,35 @@ trait Language[E]: */ def fromTree[A](n: MixedTree[Op, A])(using dec: AtomDecoder[E, A]): E + /** + * Decode a call tree back into a surface AST `E`. + * + * @param n A call tree with [[Op]] internal nodes and [[EClassCall]] leaves. + * @param dec A given [[AtomDecoder]] that knows how to turn `EClassCall` atoms back into `E` fragments. + * @return The reconstructed surface expression. + * + * Note: decoding is typically partial in the presence of analysis-only atoms; see [[fromAnalysisNode]]. + */ + def fromTree(n: CallTree[Op])(using dec: AtomDecoder[E, EClassCall]): E = { + fromTree[EClassCall](n.toMixedTree) + } + + /** + * Encode a surface AST `e: E` into a call tree with [[EClassCall]] leaves. + * + * @param e A surface AST node. + * @param enc A given [[AtomEncoder]] that knows how to encode `E` atoms into `EClassCall`. + * @return A call tree equivalent to `e`. + * + * @example + * {{{ + * val callTree: CallTree[Lang.Op] = Lang.toCallTree(surfaceExpr) + * }}} + */ + def toCallTree(e: E)(using enc: AtomEncoder[E, EClassCall]): CallTree[Op] = { + CallTree.from(toTree[EClassCall](e)) + } + /** * Encode a surface AST `e: E` into an immutable e-graph, returning the root e-class call * and the new e-graph containing `e`. @@ -347,6 +376,19 @@ trait Language[E]: (using enc: AtomEncoder[E, Pattern.Var]): ReversibleSearcher[Op, PatternMatch[Op], EGraphT] = toTree(e).toSearcher + /** + * Build a borrowing searcher from a surface expression, by first encoding to a pattern tree. + * + * Requires an encoder for [[foresight.eqsat.rewriting.patterns.Pattern.Var]] leaves, i.e., a way to turn surface + * leaves into pattern variables. + * + * @param e Surface-side pattern. + * @tparam EGraphT An e-graph type that supports this language. + */ + def toBorrowingSearcher[EGraphT <: EGraph[Op]](e: E) + (using enc: AtomEncoder[E, Pattern.Var]): Searcher[Op, AbstractPatternMatch[Op], EGraphT] = + StateBorrowingMachineEClassSearcher(toTree(e).compiled) + /** * Build a pattern applier from a surface expression, by first encoding to a pattern tree. * @@ -380,10 +422,39 @@ trait Language[E]: : Rule[Op, PatternMatch[Op], EGraphT] = Rule(name, toSearcher[EGraphT](lhs), toApplier[EGraphT](rhs)) + /** + * Construct a rewrite rule that uses a borrowing searcher under the hood. + * + * @param name Human-readable rule name (used in logs/diagnostics). + * @param lhs Surface pattern to match. + * @param rhs Surface template to build. + * @param enc Encoder for [[foresight.eqsat.rewriting.patterns.Pattern.Var]] leaves. + * @tparam EGraphT An e-graph type that supports this language. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String, lhs: E, rhs: E) + (using enc: AtomEncoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + Rule(name, toBorrowingSearcher[EGraphT](lhs), toApplier[EGraphT](rhs)) + /** Generate a fresh [[foresight.eqsat.rewriting.patterns.Pattern.Var]] and decode it back to a surface `E` value. */ private def freshPatternVar(using dec: AtomDecoder[E, Pattern.Var]): E = fromTree(MixedTree.Atom[Op, Pattern.Var](Pattern.Var.fresh())) + /** + * Create a borrowing rewrite rule by supplying a lambda that receives one fresh pattern variable. + * + * Works like [[rule(name)(f: E => (E,E))]] but compiles the LHS with a borrowing searcher. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: E => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives one fresh pattern variable. * @@ -452,6 +523,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with two fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives three fresh pattern variables. * @@ -480,6 +563,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with three fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives four fresh pattern variables. * @@ -509,6 +604,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with four fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives five fresh pattern variables. * @@ -540,6 +647,20 @@ trait Language[E]: ) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with five fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f( + freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar + ) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives six fresh pattern variables. * @@ -571,6 +692,20 @@ trait Language[E]: ) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with six fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f( + freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar + ) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Reconstruct a surface expression `E` from an *analysis-view* of a single node. * diff --git a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala index d6ee2399..6a1947dc 100644 --- a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala +++ b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala @@ -1,8 +1,10 @@ package foresight.eqsat -import foresight.eqsat.collections.{SlotMap, SlotSet} +import foresight.eqsat.collections.{SlotMap, SlotSeq, SlotSet} import foresight.eqsat.readonly.EGraph +import scala.collection.compat.immutable.ArraySeq + /** * Represents the application of an [[EClassRef]] to a set of argument slots. * @@ -29,7 +31,7 @@ import foresight.eqsat.readonly.EGraph * val call2 = EClassCall(subXY, SlotMap(x -> b, y -> a)) // represents "b - a" * }}} */ -final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol { +final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol with CallTree[Nothing] { /** * The set of slots used as arguments in this application, in sequence order. * These are the slots referenced by the argument values, not the parameter slots. @@ -39,7 +41,7 @@ final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol /** * The set of distinct slots used as arguments in this application. */ - def slotSet: SlotSet = args.valueSet + override def slotSet: SlotSet = args.valueSet /** * Renames all argument slots in this application according to a given mapping. @@ -208,3 +210,78 @@ object EClassSymbol { */ def real(call: EClassCall): Real = call } + +/** + * A tree of nodes and e-class calls, representing expressions in an e-graph. + * @tparam NodeT The type of the nodes in the e-graph. + */ +sealed trait CallTree[+NodeT] { + /** + * The set of distinct slots used in this call tree. + */ + def slotSet: Set[Slot] = this match { + case call: EClassCall => call.slotSet + case CallTree.Node(_, defs, uses, children) => + defs.toSet ++ uses ++ children.flatMap(_.slotSet) + } + + /** + * Converts this call tree into a mixed tree of nodes and e-class calls. + * @return The resulting mixed tree. + */ + final def toMixedTree: MixedTree[NodeT, EClassCall] = this match { + case call: EClassCall => MixedTree.Atom(call) + case CallTree.Node(n, defs, uses, children) => + MixedTree.Node(n, defs, uses, children.map(_.toMixedTree)) + } + + /** + * Maps all [[EClassCall]] nodes in this call tree using a given function. + * @param f The mapping function. + * @return The resulting call tree with mapped calls. + */ + final def mapCalls(f: EClassCall => EClassCall): CallTree[NodeT] = this match { + case call: EClassCall => f(call) + case CallTree.Node(n, defs, uses, children) => + CallTree.Node(n, defs, uses, children.map(_.mapCalls(f))) + } +} + +/** + * Constructors for [[CallTree]] nodes. + */ +object CallTree { + /** + * Represents a node in the call tree. + * @param node The node. + * @param definitions The slots defined by this node. + * @param uses The slots used by this node. + * @param children The child call trees. + * @tparam NodeT The type of the nodes in the e-graph. + */ + final case class Node[+NodeT](node: NodeT, definitions: SlotSeq, uses: SlotSeq, children: ArraySeq[CallTree[NodeT]]) + extends CallTree[NodeT] + + /** + * Converts a [[MixedTree]] of [[EClassCall]]s into a [[CallTree]]. + * @param tree The mixed tree to convert. + * @tparam NodeT The type of the nodes in the e-graph. + * @return The resulting call tree. + */ + def from[NodeT](tree: MixedTree[NodeT, EClassCall]): CallTree[NodeT] = tree match { + case MixedTree.Atom(call) => call + case MixedTree.Node(n, defs, uses, args) => + Node(n, defs, uses, args.map(arg => from(arg))) + } + + /** + * Converts a [[Tree]] into a [[CallTree]]. + * @param tree The tree to convert. + * @tparam NodeT The type of the nodes in the e-graph. + * @return The resulting call tree. + */ + def from[NodeT](tree: Tree[NodeT]): CallTree[NodeT] = tree match { + case Tree(n, defs, uses, children) => + Node(n, defs, uses, children.map(child => from(child))) + } +} diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index fec0fdfe..6d2f768c 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -18,18 +18,24 @@ import scala.collection.compat.immutable.ArraySeq * * Node types (`NodeT`) supply the operator and any non-structural payload. Slots and arguments are provided here. * - * @param nodeType Operator or symbol for this node. * @tparam NodeT The domain-specific node type. It defines operator identity and payload but not slots/children. */ final class ENode[+NodeT] private ( - val nodeType: NodeT, - private val _definitions: Array[Slot], - private val _uses: Array[Slot], - private val _args: Array[EClassCall] + private var _nodeType: Any, + private var _definitions: Array[Slot], + private var _uses: Array[Slot], + private var _args: Array[EClassCall] ) extends Node[NodeT, EClassCall] with ENodeSymbol[NodeT] { // Cached hash code to make hashing and equality fast (benign data race; like String.hash) private var _hash: Int = 0 + /** + * The operator or symbol for this node. + * + * @return The node type. + */ + def nodeType: NodeT = _nodeType.asInstanceOf[NodeT] + /** * Slots introduced by this node that are scoped locally and invisible to parents. These are * redundant by construction at the boundary of this node and exist to model binders such as @@ -60,6 +66,16 @@ final class ENode[+NodeT] private ( */ private[eqsat] def unsafeArgsArray: Array[EClassCall] = _args + /** + * Unsafe access to the internal array of definition slots. Do not modify. + */ + private[eqsat] def unsafeDefinitionsArray: Array[Slot] = _definitions + + /** + * Unsafe access to the internal array of use slots. Do not modify. + */ + private[eqsat] def unsafeUsesArray: Array[Slot] = _uses + /** * The total number of slots occurring in this node: definitions, uses, and children’s argument slots. * Includes slots that may be duplicated across these categories. @@ -407,17 +423,254 @@ final class ENode[+NodeT] private ( * Constructors and helpers for [[ENode]]. */ object ENode { - private[eqsat] def arraysEqual[T](a: Array[T], b: Array[T]): Boolean = { - if (a eq b) return true - if (a.length != b.length) return false - var i = 0 - while (i < a.length) { - if (a(i) != b(i)) return false - i += 1 + private val newSlotDequeDelegate = new java.util.function.Function[Int, java.util.ArrayDeque[Array[Slot]]] { + override def apply(t: Int): java.util.ArrayDeque[Array[Slot]] = new java.util.ArrayDeque[Array[Slot]] + } + + private val newCallDequeDelegate = new java.util.function.Function[Int, java.util.ArrayDeque[Array[EClassCall]]] { + override def apply(t: Int): java.util.ArrayDeque[Array[EClassCall]] = new java.util.ArrayDeque[Array[EClassCall]] + } + + /** + * Pool for building and recycling `ENode`s with re-usable backing arrays. + * + * This pool both reuses `ENode` objects themselves and their backing arrays for `definitions`, + * `uses`, and `args` to cut allocations on hot paths. + */ + final class Pool private[ENode] ( + private val perBucketCap: Int, + private val maxBucketLen: Int = 16 + ) { + // Buckets indexed by length -> stack of arrays + private val slotBuckets: Array[java.util.ArrayDeque[Array[Slot]]] = + new Array[java.util.ArrayDeque[Array[Slot]]](maxBucketLen + 1) + private val callBuckets: Array[java.util.ArrayDeque[Array[EClassCall]]] = + new Array[java.util.ArrayDeque[Array[EClassCall]]](maxBucketLen + 1) + + // Eagerly initialize deques for all bucket lengths + { + var i = 0 + while (i <= maxBucketLen) { + slotBuckets(i) = new java.util.ArrayDeque[Array[Slot]]() + callBuckets(i) = new java.util.ArrayDeque[Array[EClassCall]]() + i += 1 + } + } + + @inline private def slotDeque(len: Int): java.util.ArrayDeque[Array[Slot]] = { + if (len < 0 || len > maxBucketLen) null + else slotBuckets(len) + } + + @inline private def callDeque(len: Int): java.util.ArrayDeque[Array[EClassCall]] = { + if (len < 0 || len > maxBucketLen) null + else callBuckets(len) + } + + // Free-list of reusable ENode objects + private val nodeFree = new java.util.ArrayDeque[ENode[Any]]() + private def borrowNode(): ENode[Any] = { + if (nodeFree.isEmpty) new ENode[Any](null.asInstanceOf[Any], emptySlotArray, emptySlotArray, emptyCallArray) + else nodeFree.removeFirst() + } + + private def prepareNode[NodeT](node: ENode[Any], nodeType: NodeT, defs: Array[Slot], uses: Array[Slot], args: Array[EClassCall]): ENode[NodeT] = { + node._nodeType = nodeType + node._definitions = defs + node._uses = uses + node._args = args + node._hash = 0 // reset cached hash + node.asInstanceOf[ENode[NodeT]] + } + + /** + * Releases an `ENode` back to this pool, along with its backing arrays. + * @param len Length of the array. + * @return An array of the given length. + */ + def acquireSlotArray(len: Int): Array[Slot] = { + if (len == 0) return emptySlotArray + val q = slotDeque(len) + if (q eq null) return new Array[Slot](len) + if (q.isEmpty) new Array[Slot](len) else q.removeFirst() + } + + /** + * Acquires a slot array filled with the given slots from this pool. + * @param slots Slots to fill into the array. + * @return An array containing the given slots. + */ + def acquireAndFillSlotArray(slots: Seq[Slot]): Array[Slot] = { + copySlotsIntoPooledArray(slots) + } + + /** + * Acquires a call array of the given length from this pool. + * @param len Length of the array. + * @return An array of the given length. + */ + def acquireCallArray(len: Int): Array[EClassCall] = { + if (len == 0) return emptyCallArray + val q = callDeque(len) + if (q eq null) return new Array[EClassCall](len) + if (q.isEmpty) new Array[EClassCall](len) else q.removeFirst() + } + + /** + * Acquires a call array filled with the given calls from this pool. + * @param calls Calls to fill into the array. + * @return An array containing the given calls. + */ + def acquireAndFillCallArray(calls: Seq[EClassCall]): Array[EClassCall] = { + copyCallsIntoPooledArray(calls) + } + + /** + * Releases a slot array back to this pool. + * @param arr Array to release. + */ + def releaseSlotArray(arr: Array[Slot]): Unit = { + if (arr eq null) return + val len = arr.length + if (len == 0) return + // java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + val q = slotDeque(len) + if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) + } + + /** + * Releases a call array back to this pool. + * @param arr Array to release. + */ + def releaseCallArray(arr: Array[EClassCall]): Unit = { + if (arr eq null) return + val len = arr.length + if (len == 0) return + // java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + val q = callDeque(len) + if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) + } + + private def copySlotsIntoPooledArray(slots: Seq[Slot]): Array[Slot] = slots match { + case slotSeq: SlotSeq => + val len = slotSeq.size + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + var i = 0 + while (i < len) { arr(i) = slotSeq.unsafeArray(i); i += 1 } + arr + } + case as: scala.collection.compat.immutable.ArraySeq[Slot] if as.unsafeArray.isInstanceOf[Array[Slot]] => + val src = as.unsafeArray.asInstanceOf[Array[Slot]] + val len = src.length + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + System.arraycopy(src, 0, arr, 0, len) + arr + } + case _ => + val len = slots.size + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + var i = 0 + val it = slots.iterator + while (i < len && it.hasNext) { arr(i) = it.next(); i += 1 } + arr + } + } + + private def copyCallsIntoPooledArray(calls: Seq[EClassCall]): Array[EClassCall] = calls match { + case as: scala.collection.compat.immutable.ArraySeq[EClassCall] if as.unsafeArray.isInstanceOf[Array[EClassCall]] => + val src = as.unsafeArray.asInstanceOf[Array[EClassCall]] + val len = src.length + if (len == 0) emptyCallArray + else { + val arr = acquireCallArray(len) + System.arraycopy(src, 0, arr, 0, len) + arr + } + case _ => + val len = calls.size + if (len == 0) emptyCallArray + else { + val arr = acquireCallArray(len) + var i = 0 + val it = calls.iterator + while (i < len && it.hasNext) { arr(i) = it.next(); i += 1 } + arr + } + } + + /** + * Builds an `ENode` whose backing arrays are owned by this pool and therefore recyclable via [[release]]. + */ + def acquire[NodeT](nodeType: NodeT, definitions: Seq[Slot], uses: Seq[Slot], args: Seq[EClassCall]): ENode[NodeT] = { + val defsArr = copySlotsIntoPooledArray(definitions) + val usesArr = copySlotsIntoPooledArray(uses) + val argsArr = copyCallsIntoPooledArray(args) + acquireUnsafe(nodeType, defsArr, usesArr, argsArr) + } + + /** + * Builds an `ENode` whose backing arrays are owned by this pool and therefore recyclable via [[release]]. + * + * The caller must ensure that the provided arrays are not used or retained after calling this method. + */ + private[eqsat] def acquireUnsafe[NodeT](nodeType: NodeT, definitions: Array[Slot], uses: Array[Slot], args: Array[EClassCall]): ENode[NodeT] = { + val nAny = borrowNode() + val n = prepareNode(nAny, nodeType, definitions, uses, args) + n + } + + /** + * Returns `node` to this pool. After calling this, the caller must not keep or use `node`. + */ + def release(node: ENode[_]): Unit = { + // Return backing arrays if they belong to this pool + releaseSlotArray(node.unsafeDefinitionsArray) + releaseSlotArray(node.unsafeUsesArray) + releaseCallArray(node.unsafeArgsArray) + + val nAny = node.asInstanceOf[ENode[Any]] + // Clear fields to avoid retaining references and to mark as blank + nAny._nodeType = null + nAny._definitions = emptySlotArray + nAny._uses = emptySlotArray + nAny._args = emptyCallArray + nAny._hash = 0 + if (nodeFree.size() < perBucketCap) nodeFree.addFirst(nAny) + } + + /** Clears all buckets. */ + def clear(): Unit = { + var i = 0 + while (i < slotBuckets.length) { + val q = slotBuckets(i) + if (q != null) q.clear() + i += 1 + } + i = 0 + while (i < callBuckets.length) { + val q = callBuckets(i) + if (q != null) q.clear() + i += 1 + } + nodeFree.clear() } - true } + /** Thread-local default pool for lightweight reuse without wiring one through call sites. */ + private val threadLocalDefaultPool = new ThreadLocal[Pool] { override def initialValue(): Pool = new Pool(64) } + + /** Creates a new pool with the given per-bucket capacity. */ + def newPool(perBucketCap: Int = 64): Pool = new Pool(perBucketCap) + + /** Gets the thread-local default pool. */ + def defaultPool: Pool = threadLocalDefaultPool.get() + private val emptySlotArray: Array[Slot] = Array.empty private val emptyCallArray: Array[EClassCall] = Array.empty diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index a8c9fbf2..bea9a877 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -2,7 +2,7 @@ package foresight.eqsat.commands import foresight.eqsat.collections.SlotSeq import foresight.eqsat.readonly.EGraph -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} import foresight.util.Debug import foresight.util.collections.UnsafeSeqFromArray @@ -81,30 +81,37 @@ trait CommandScheduleBuilder[NodeT] { } } - private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + private[eqsat] def addSimplifiedReal(tree: CallTree[NodeT], egraph: EGraph[NodeT]): EClassSymbol = { - val maxBatch = new IntRef(0) - addSimplifiedReal(tree, egraph, maxBatch) + val refPool = IntRef.defaultPool + val maxBatch = refPool.acquire(0) + val result = addSimplifiedReal(tree, egraph, maxBatch, ENode.defaultPool, refPool) + refPool.release(maxBatch) + result } - private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + private[eqsat] def addSimplifiedReal(tree: CallTree[NodeT], egraph: EGraph[NodeT], - maxBatch: IntRef): EClassSymbol = { + maxBatch: IntRef, + nodePool: ENode.Pool, + refPool: IntRef.Pool): EClassSymbol = { tree match { - case MixedTree.Node(t, defs, uses, args) => + case CallTree.Node(t, defs, uses, args) => // Local accumulator for children of this node. - val childMax = new IntRef(0) + val childMax = refPool.acquire(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, - (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb) + nodePool, + (child: CallTree[NodeT], mb: IntRef) => addSimplifiedReal(child, egraph, mb, nodePool, refPool) ) - val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph) + val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, nodePool) // Propagate maximum required batch up to the caller's accumulator. if (childMax.elem > maxBatch.elem) maxBatch.elem = childMax.elem + refPool.release(childMax) sym - case MixedTree.Atom(call) => + case call: EClassCall => // No insertion required; keep caller's accumulator unchanged. EClassSymbol.real(call) } @@ -115,7 +122,8 @@ trait CommandScheduleBuilder[NodeT] { uses: SlotSeq, args: Array[EClassSymbol], maxBatch: IntRef, - egraph: EGraph[NodeT]): EClassSymbol = { + egraph: EGraph[NodeT], + pool: ENode.Pool): EClassSymbol = { // Check if all children are already in the graph. val argCalls = CommandScheduleBuilder.resolveAllOrNull(args) @@ -126,7 +134,12 @@ trait CommandScheduleBuilder[NodeT] { assert(maxBatch.elem == 0) } - val candidateNode = ENode.unsafeWrapArrays(nodeType, definitions, uses, argCalls) + val candidateNode = pool.acquireUnsafe( + nodeType, + pool.acquireAndFillSlotArray(definitions), + pool.acquireAndFillSlotArray(uses), + argCalls) + egraph.findOrNull(candidateNode) match { case null => // Node does not exist in the graph but its children do exist in the graph. @@ -135,6 +148,7 @@ trait CommandScheduleBuilder[NodeT] { case existingCall => // Node already exists in the graph; reuse its class. + pool.release(candidateNode) EClassSymbol.real(existingCall) } } else { @@ -157,14 +171,14 @@ object CommandScheduleBuilder { */ def newConcurrentBuilder[NodeT]: CommandScheduleBuilder[NodeT] = new ConcurrentCommandScheduleBuilder[NodeT]() - private[eqsat] def symbolArrayFrom[A](values: ArraySeq[A], maxBatch: IntRef, valueToSymbol: (A, IntRef) => EClassSymbol): Array[EClassSymbol] = { + private[eqsat] def symbolArrayFrom[A](values: ArraySeq[A], maxBatch: IntRef, pool: ENode.Pool, valueToSymbol: (A, IntRef) => EClassSymbol): Array[EClassSymbol] = { // Try to avoid allocating an array of EClassSymbol if all entries are EClassCall. // The common case is that all children are already in the e-graph, and we will // want to construct an ENode with an Array[EClassCall]. // If we find any entry that is not an EClassCall, we fall back to allocating // an Array[EClassSymbol] and copying the prefix of calls. val n = values.length - val calls = new Array[EClassCall](n) + val calls = pool.acquireCallArray(n) var i = 0 while (i < n) { valueToSymbol(values(i), maxBatch) match { @@ -175,8 +189,10 @@ object CommandScheduleBuilder { val syms = new Array[EClassSymbol](n) var j = 0 while (j < i) { - syms(j) = calls(j); j += 1 + syms(j) = calls(j) + j += 1 } + pool.releaseCallArray(calls) syms(i) = other j = i + 1 while (j < n) { diff --git a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala index f0992cd6..8f5dda17 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala @@ -1,3 +1,47 @@ package foresight.eqsat.commands private[eqsat] final class IntRef(var elem: Int) + +object IntRef { + /** + * Pool of reusable [[IntRef]] instances. + * + * Usage: + * val r = IntRef.acquire() // from the default thread-local pool + * r.elem = 42 + * IntRef.release(r) // return to pool + */ + final class Pool { + private final val maxSize = 64 + // LIFO stack to maximize cache locality + private val free = new java.util.ArrayDeque[IntRef](maxSize) + + /** Acquire an IntRef, initializing its value. */ + @inline def acquire(initial: Int): IntRef = { + val ref = free.pollFirst() + if (ref eq null) new IntRef(initial) + else { ref.elem = initial; ref } + } + + /** Return an IntRef to the pool for reuse. */ + @inline def release(ref: IntRef): Unit = { + if (free.size() >= maxSize) return + // no double-free tracking for performance; callers ensure discipline + free.addFirst(ref) + } + + /** Number of currently stored reusable objects. */ + @inline def size: Int = free.size() + + /** Drop all cached objects. */ + def clear(): Unit = free.clear() + } + + // Default thread-local pool. + private val threadLocal: ThreadLocal[Pool] = new ThreadLocal[Pool] { + override def initialValue(): Pool = new Pool + } + + /** Access the default thread-local pool for the current thread. */ + @inline def defaultPool: Pool = threadLocal.get() +} \ No newline at end of file diff --git a/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala b/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala index 634011a8..266dc0c4 100644 --- a/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala +++ b/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala @@ -1,8 +1,7 @@ package foresight.eqsat.extraction -import foresight.eqsat.{EClassCall, MixedTree, Tree} +import foresight.eqsat.{CallTree, EClassCall, MixedTree, Tree, readonly} import foresight.eqsat.immutable.{EGraph, EGraphLike} -import foresight.eqsat.readonly /** * An extractor that converts e-graph references (e-class calls) into concrete expression trees. @@ -54,4 +53,26 @@ trait Extractor[NodeT, -Repr <: readonly.EGraph[NodeT]] { apply(call, egraph) } } + + /** + * Extracts a concrete expression tree from a [[CallTree]] by extracting an expression for each + * e-class call within it. + * + * @param tree The call tree to materialize prior to extraction. + * @param egraph The original e-graph; remains unchanged. + * @return A concrete [[Tree]] extracted from the materialized call. + * + * @example + * {{{ + * val result: Tree[NodeT] = extractor(callTree, egraph) // `egraph` is unchanged + * }}} + */ + final def apply(tree: CallTree[NodeT], egraph: Repr): Tree[NodeT] = { + tree match { + case call: EClassCall => apply(call, egraph) + case CallTree.Node(n, defs, uses, children) => + val extractedChildren = children.map(child => apply(child, egraph)) + Tree(n, defs, uses, extractedChildren) + } + } } diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala index 74e597b3..15c50486 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala @@ -1,8 +1,7 @@ package foresight.eqsat.immutable -import foresight.eqsat.{AddNodeResult, EClassCall, ENode, MixedTree, Tree} +import foresight.eqsat.{AddNodeResult, CallTree, EClassCall, ENode, MixedTree, Tree, readonly} import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.readonly import scala.collection.compat.immutable.ArraySeq @@ -148,6 +147,27 @@ trait EGraphLike[NodeT, +This <: EGraphLike[NodeT, This] with EGraph[NodeT]] ext } } + /** + * Adds a call tree to the e-graph. + * + * Child subtrees are added/resolved first; then the root e-node is added or found. + * Returns a new e-graph containing the result. + * + * @param tree The call tree to add. + * @return (E-class of the root, new e-graph). + */ + final def add(tree: CallTree[NodeT]): (EClassCall, This) = { + tree match { + case call: EClassCall => (call, this.asInstanceOf[This]) + case CallTree.Node(t, defs, uses, args) => + val (newArgs, graphWithArgs) = args.foldLeft((Seq.empty[EClassCall], this.asInstanceOf[This]))((acc, arg) => { + val (node, egraph) = acc._2.add(arg) + (acc._1 :+ node, egraph) + }) + graphWithArgs.add(ENode(t, defs, uses, newArgs)) + } + } + /** * Adds a pure tree to the e-graph. * diff --git a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala index a7b7e03e..66790b8d 100644 --- a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala @@ -1,6 +1,6 @@ package foresight.eqsat.readonly -import foresight.eqsat.{EClassCall, EClassRef, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassRef, MixedTree} import foresight.eqsat.metadata.Analysis /** @@ -55,6 +55,28 @@ trait AnalysisMetadata[NodeT, A] { } } + /** + * Evaluate a call tree whose leaves are either e-class applications or concrete nodes. + * + * For a call leaf, this delegates to [[apply(EClassCall,EGraph)]]. For a node leaf, it first + * computes the results of all argument subtrees and then invokes the analysis transfer function + * [[Analysis.make]] using the node’s definitions and uses provided by the tree. + * + * @param tree The call tree to evaluate. + * @param egraph The e-graph to resolve calls and canonicalization. + * @return The analysis result for the whole tree. + */ + final def apply(tree: CallTree[NodeT], + egraph: EGraph[NodeT]): A = { + tree match { + case CallTree.Node(node, defs, uses, args) => + val argsResults = args.map(apply(_, egraph)) + analysis.make(node, defs, uses, argsResults) + + case call: EClassCall => apply(call, egraph) + } + } + /** * Compute the analysis result for an already-canonicalized e-class application. * diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala new file mode 100644 index 00000000..cf581cd6 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala @@ -0,0 +1,39 @@ +package foresight.eqsat.rewriting.patterns + +import foresight.eqsat.{CallTree, EClassCall, Slot} + +/** + * An abstract pattern match that maps pattern variables to trees and slots. + * + * @tparam NodeT The type of the nodes in the e-graph. + */ +trait AbstractPatternMatch[NodeT] { + /** + * The e-class in which the pattern was found. + */ + def root: EClassCall + + /** + * Gets the tree that corresponds to a variable. + * + * @param variable The variable. + * @return The tree. + */ + def apply(variable: Pattern.Var): CallTree[NodeT] + + /** + * Gets the slot that corresponds to a slot variable. + * + * @param slot The slot variable. + * @return The slot. + */ + def apply(slot: Slot): Slot + + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + def get(slot: Slot): Option[Slot] +} diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala index 93c3a41d..a267c355 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala @@ -47,6 +47,27 @@ final case class CompiledPattern[NodeT, EGraphT <: EGraph[NodeT]](pattern: Mixed }) } + /** + * Searches for matches of the pattern in an e-graph, calling a continuation for each match. + * If the continuation returns false, the search is stopped. + * + * Borrows the machine state and passes it to the continuation. The continuation must not + * retain the state after returning. + * + * @param call The e-class application to search for. + * @param egraph The e-graph to search in. + * @param continuation A continuation that is called for each match of the pattern. If the continuation returns false, + * the search is stopped. + */ + def searchBorrowed(call: EClassCall, egraph: EGraphT, continuation: (AbstractPatternMatch[NodeT], EGraphT) => Boolean): Unit = { + val state = machinePool.borrow(call) + Machine.run(egraph, state, instructions, (state: MutableMachineState[NodeT]) => { + val result = continuation(state, egraph) + state.release() + result + }) + } + /** * Searches for matches of the pattern in an e-graph. * @param call The e-class application to search for. diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala index 3e2660c7..f8ab9f0b 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala @@ -400,7 +400,7 @@ object Instruction { ) override def execute(ctx: Instruction.Execution[NodeT, EGraphT]): Boolean = { - val value = MixedTree.Atom[NodeT, EClassCall](ctx.machine.registerAt(register)) + val value = ctx.machine.registerAt(register) ctx.machine.bindVar(value) ctx.continue() } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala index 9f14e937..4e0fe32a 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala @@ -1,9 +1,9 @@ package foresight.eqsat.rewriting.patterns -import foresight.eqsat.{EClassCall, ENode, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, ENode, Slot} import foresight.util.collections.ArrayMap -import scala.collection.compat._ +import scala.collection.compat.immutable.ArraySeq /** * The state of a pattern machine. @@ -14,7 +14,7 @@ import scala.collection.compat._ * @param boundNodes The nodes that are bound in the machine. * @tparam NodeT The type of the nodes in the e-graph. */ -final case class MachineState[NodeT](registers: immutable.ArraySeq[EClassCall], - boundVars: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]], +final case class MachineState[NodeT](registers: ArraySeq[EClassCall], + boundVars: ArrayMap[Pattern.Var, CallTree[NodeT]], boundSlots: ArrayMap[Slot, Slot], - boundNodes: immutable.ArraySeq[ENode[NodeT]]) + boundNodes: ArraySeq[ENode[NodeT]]) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala index 59e27379..378a19ca 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala @@ -1,9 +1,9 @@ package foresight.eqsat.rewriting.patterns -import foresight.eqsat.{EClassCall, ENode, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, ENode, Slot} import foresight.util.collections.ArrayMap -import scala.collection.compat._ +import scala.collection.compat.immutable.ArraySeq /** * A mutable machine state that preallocates fixed-size arrays @@ -11,10 +11,11 @@ import scala.collection.compat._ */ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, private val registersArr: Array[EClassCall], - private val boundVarsArr: Array[MixedTree[NodeT, EClassCall]], + private val boundVarsArr: Array[EClassCall], private val boundSlotsArr: Array[Slot], private val boundNodesArr: Array[ENode[NodeT]], - private val homePool: MutableMachineState.Pool[NodeT]) { + private val homePool: MutableMachineState.Pool[NodeT]) + extends AbstractPatternMatch[NodeT] { private var regIdx: Int = 0 private var varIdx: Int = 0 @@ -165,12 +166,12 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, * Bind a variable to a value. * @param value The value to bind the variable to. */ - def bindVar(value: MixedTree[NodeT, EClassCall]): Unit = { + def bindVar(value: EClassCall): Unit = { boundVarsArr(varIdx) = value varIdx += 1 } - private def boundVars: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]] = { + private def boundVars: ArrayMap[Pattern.Var, CallTree[NodeT]] = { ArrayMap.unsafeWrapArrays(effects.boundVars.unsafeArray, java.util.Arrays.copyOf(boundVarsArr, varIdx), varIdx) } @@ -185,8 +186,8 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, /** Convert to an immutable MachineState snapshot. */ def freeze(): MachineState[NodeT] = { - val regs = immutable.ArraySeq.unsafeWrapArray(registersArr.slice(0, regIdx)) - val nodes = immutable.ArraySeq.unsafeWrapArray(boundNodesArr.slice(0, nodeIdx)) + val regs = ArraySeq.unsafeWrapArray(registersArr.slice(0, regIdx)) + val nodes = ArraySeq.unsafeWrapArray(boundNodesArr.slice(0, nodeIdx)) MachineState(regs, boundVars, boundSlots, nodes) } @@ -196,6 +197,51 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, def toPatternMatch: PatternMatch[NodeT] = { PatternMatch(registersArr(0), boundVars, boundSlots) } + + /** + * The e-class in which the pattern was found. + */ + override def root: EClassCall = registersArr(0) + + /** + * Gets the tree that corresponds to a variable. + * + * @param variable The variable. + * @return The tree. + */ + override def apply(variable: Pattern.Var): CallTree[NodeT] = { + val i = effects.boundVars.indexOf(variable) + if (i < 0 || i >= varIdx) + throw new NoSuchElementException(s"Variable $variable not bound in this state") + + boundVarsArr(i) + } + + /** + * Gets the slot that corresponds to a slot variable. + * + * @param slot The slot variable. + * @return The slot. + */ + override def apply(slot: Slot): Slot = { + val i = effects.boundSlots.indexOf(slot) + if (i < 0 || i >= slotIdx) + throw new NoSuchElementException(s"Slot $slot not bound in this state") + + boundSlotsArr(i) + } + + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + override def get(slot: Slot): Option[Slot] = { + val i = effects.boundSlots.indexOf(slot) + if (i >= 0 && i < slotIdx) Some(boundSlotsArr(i)) + else None + } } /** @@ -206,7 +252,7 @@ object MutableMachineState { // Preallocate single empty arrays of the right types to avoid repeated allocations // when effects report zero bound vars/slots/nodes. // Unsafe casts are safe because these arrays are never written to. - private val emptyVars = new Array[MixedTree[_, EClassCall]](0) + private val emptyVars = new Array[EClassCall](0) private val emptySlots = new Array[Slot](0) private val emptyNodes = new Array[ENode[_]](0) @@ -221,7 +267,7 @@ object MutableMachineState { val m = new MutableMachineState[NodeT]( effects, new Array[EClassCall](1 + effects.createdRegisters), - if (varsLen > 0) new Array[MixedTree[NodeT, EClassCall]](varsLen) else emptyVars.asInstanceOf[Array[MixedTree[NodeT, EClassCall]]], + if (varsLen > 0) new Array[EClassCall](varsLen) else emptyVars, if (slotsLen > 0) new Array[Slot](slotsLen) else emptySlots, if (nodesLen > 0) new Array[ENode[NodeT]](nodesLen) else emptyNodes.asInstanceOf[Array[ENode[NodeT]]], pool diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 91aedae9..07ab5102 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -3,10 +3,11 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.collections.SlotSeq import foresight.eqsat.commands.{CommandScheduleBuilder, IntRef} import foresight.eqsat.readonly.EGraph -import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} -import foresight.eqsat.{EClassSymbol, MixedTree, Slot} +import foresight.eqsat.rewriting.{Applier, ReversibleApplier, Searcher} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq +import java.util.ArrayDeque /** * An applier that applies a pattern match to an e-graph. @@ -16,9 +17,9 @@ import scala.collection.compat.immutable.ArraySeq * @tparam EGraphT The type of the e-graph that the applier applies the match to. */ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedTree[NodeT, Pattern.Var]) - extends ReversibleApplier[NodeT, PatternMatch[NodeT], EGraphT] { + extends ReversibleApplier[NodeT, PatternMatch[NodeT], EGraphT] with Applier[NodeT, AbstractPatternMatch[NodeT], EGraphT] { - override def apply(m: PatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { + override def apply(m: AbstractPatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { val symbol = instantiateAsSimplifiedAddCommand(pattern, m, egraph, builder) builder.unionSimplified(EClassSymbol.real(m.root), symbol, egraph) } @@ -46,7 +47,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT m: PatternMatch[NodeT]): MixedTree[NodeT, EClassSymbol] = { pattern match { case MixedTree.Atom(p) => p match { - case v: Pattern.Var => m(v).mapAtoms(EClassSymbol.real) + case v: Pattern.Var => m(v).toMixedTree.mapAtoms(EClassSymbol.real) } case MixedTree.Node(t, Seq(), uses, args) => @@ -65,9 +66,35 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } - private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], - egraph: EGraphT, - builder: CommandScheduleBuilder[NodeT]) { + private final class SimplifiedAddCommandInstantiator { + // Mutable state to allow pooling; initialized via init() before use. + private var m: AbstractPatternMatch[NodeT] = _ + private var egraph: EGraphT = _ + private var builder: CommandScheduleBuilder[NodeT] = _ + private var nodePool: ENode.Pool = _ + private var refPool: IntRef.Pool = _ + + def init(m0: AbstractPatternMatch[NodeT], + egraph0: EGraphT, + builder0: CommandScheduleBuilder[NodeT], + nodePool0: ENode.Pool, + refPool0: IntRef.Pool): Unit = { + this.m = m0 + this.egraph = egraph0 + this.builder = builder0 + this.nodePool = nodePool0 + this.refPool = refPool0 + } + + def clear(): Unit = { + // release references to help GC when pooled instances sit around + this.m = null + this.egraph = null.asInstanceOf[EGraphT] + this.builder = null + this.nodePool = null + this.refPool = null + } + def instantiate(pattern: MixedTree[NodeT, Pattern.Var], maxBatch: IntRef): EClassSymbol = { pattern match { case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) @@ -77,13 +104,42 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT case MixedTree.Node(t, defs, uses, args) => val defSlots = defs.map { (s: Slot) => - m.slotMapping.get(s) match { + m.get(s) match { case Some(v) => v case None => Slot.fresh() } } - val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - new SimplifiedAddCommandInstantiator(newMatch, egraph, builder).addSimplifiedNode(t, defSlots, uses, args, maxBatch) + + val newMatch = new AbstractPatternMatch[NodeT] { + override def root: EClassCall = m.root + override def apply(variable: Pattern.Var): CallTree[NodeT] = m.apply(variable) + override def apply(slot: Slot): Slot = { + if (defs.contains(slot)) { + val index = defs.indexOf(slot) + defSlots(index) + } else { + m.apply(slot) + } + } + override def get(slot: Slot): Option[Slot] = { + if (defs.contains(slot)) { + val index = defs.indexOf(slot) + Some(defSlots(index)) + } else { + m.get(slot) + } + } + } + + // Acquire a nested instantiator. + val nested = SimplifiedAddCommandInstantiator.acquire() + nested.init(newMatch, egraph, builder, nodePool, refPool) + try { + nested.addSimplifiedNode(t, defSlots, uses, args, maxBatch) + } finally { + nested.clear() + SimplifiedAddCommandInstantiator.release(nested) + } } } @@ -92,22 +148,54 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT uses: SlotSeq, args: ArraySeq[MixedTree[NodeT, Pattern.Var]], maxBatch: IntRef): EClassSymbol = { - val argMaxBatch = new IntRef(0) - val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, instantiate) + val argMaxBatch = refPool.acquire(0) + val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, nodePool, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) - val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph) + val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph, nodePool) if (argMaxBatch.elem > maxBatch.elem) { maxBatch.elem = argMaxBatch.elem } + refPool.release(argMaxBatch) result } } + private object SimplifiedAddCommandInstantiator { + // Per-thread pool to avoid contention; LIFO to improve cache locality. + private val local: ThreadLocal[ArrayDeque[SimplifiedAddCommandInstantiator]] = + new ThreadLocal[ArrayDeque[SimplifiedAddCommandInstantiator]]() { + override def initialValue(): ArrayDeque[SimplifiedAddCommandInstantiator] = + new ArrayDeque[SimplifiedAddCommandInstantiator]() + } + + def acquire(): SimplifiedAddCommandInstantiator = { + val dq = local.get() + val inst = dq.pollFirst() + if (inst != null) inst else new SimplifiedAddCommandInstantiator + } + + def release(inst: SimplifiedAddCommandInstantiator): Unit = { + local.get().offerFirst(inst) + } + } + private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], - m: PatternMatch[NodeT], + m: AbstractPatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): EClassSymbol = { - new SimplifiedAddCommandInstantiator(m, egraph, builder).instantiate(pattern, new IntRef(0)) + val refPool = IntRef.defaultPool + val ref = refPool.acquire(0) + val inst = SimplifiedAddCommandInstantiator.acquire() + inst.init(m, egraph, builder, ENode.defaultPool, refPool) + val result = + try { + inst.instantiate(pattern, ref) + } finally { + inst.clear() + SimplifiedAddCommandInstantiator.release(inst) + } + refPool.release(ref) + result } } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala index b593faed..64c77a06 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala @@ -2,7 +2,7 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.PortableMatch -import foresight.eqsat.{EClassCall, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, MixedTree, Slot} import foresight.util.collections.ArrayMap import foresight.util.collections.StrictMapOps.toStrictMapOps @@ -15,22 +15,30 @@ import foresight.util.collections.StrictMapOps.toStrictMapOps * @tparam NodeT The type of the nodes in the e-graph. */ final case class PatternMatch[NodeT](root: EClassCall, - varMapping: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]], - slotMapping: ArrayMap[Slot, Slot]) extends PortableMatch[NodeT, PatternMatch[NodeT]] { + varMapping: ArrayMap[Pattern.Var, CallTree[NodeT]], + slotMapping: ArrayMap[Slot, Slot]) + extends AbstractPatternMatch[NodeT] with PortableMatch[NodeT, PatternMatch[NodeT]] { /** * Gets the tree that corresponds to a variable. * @param variable The variable. * @return The tree. */ - def apply(variable: Pattern.Var): MixedTree[NodeT, EClassCall] = varMapping(variable) + override def apply(variable: Pattern.Var): CallTree[NodeT] = varMapping(variable) /** * Gets the slot that corresponds to a slot variable. * @param slot The slot variable. * @return The slot. */ - def apply(slot: Slot): Slot = slotMapping(slot) + override def apply(slot: Slot): Slot = slotMapping(slot) + + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + override def get(slot: Slot): Option[Slot] = slotMapping.get(slot) /** * Creates an updated match with a new variable binding. @@ -38,7 +46,7 @@ final case class PatternMatch[NodeT](root: EClassCall, * @param value The value to bind the variable to. * @return The updated match with the new binding. */ - def bind(variable: Pattern.Var, value: MixedTree[NodeT, EClassCall]): PatternMatch[NodeT] = { + def bind(variable: Pattern.Var, value: CallTree[NodeT]): PatternMatch[NodeT] = { PatternMatch(root, varMapping.updated(variable, value), slotMapping) } @@ -74,7 +82,7 @@ final case class PatternMatch[NodeT](root: EClassCall, override def port(egraph: EGraph[NodeT]): PatternMatch[NodeT] = { val newRoot = egraph.canonicalize(root) - val newVarMapping = varMapping.mapValuesStrict(_.mapAtoms(egraph.canonicalize)) + val newVarMapping = varMapping.mapValuesStrict(_.mapCalls(egraph.canonicalize)) val newSlotMapping = slotMapping PatternMatch(newRoot, newVarMapping, newSlotMapping) } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala new file mode 100644 index 00000000..d614a319 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala @@ -0,0 +1,49 @@ +package foresight.eqsat.rewriting.patterns + +import foresight.eqsat.rewriting.{Applier, EClassSearcher, ReversibleSearcher, SearcherContinuation} +import foresight.eqsat.EClassCall +import foresight.eqsat.readonly.EGraph + +/** + * A phase of a searcher that searches for matches of a pattern machine in an e-graph. + * + * This searcher uses a borrowing mechanism for the machine state, allowing for more efficient + * searches by reusing state without the overhead of copying. + * + * @param pattern The pattern to search for. + * @tparam NodeT The type of the nodes in the e-graph. + * @tparam EGraphT The type of the e-graph that the searcher searches in. + */ +final case class StateBorrowingMachineEClassSearcher[ + NodeT, + EGraphT <: EGraph[NodeT] +](pattern: CompiledPattern[NodeT, EGraphT], + buildContinuation: SearcherContinuation.ContinuationBuilder[NodeT, AbstractPatternMatch[NodeT], EGraphT]) + + extends EClassSearcher[NodeT, AbstractPatternMatch[NodeT], EGraphT] { + + protected override def search(call: EClassCall, egraph: EGraphT, continuation: SearcherContinuation.Continuation[NodeT, AbstractPatternMatch[NodeT], EGraphT]): Unit = { + pattern.searchBorrowed(call, egraph, continuation) + } + + override def withContinuationBuilder(continuation: SearcherContinuation.ContinuationBuilder[NodeT, AbstractPatternMatch[NodeT], EGraphT]): StateBorrowingMachineEClassSearcher[NodeT, EGraphT] = { + copy(buildContinuation = continuation) + } +} + +object StateBorrowingMachineEClassSearcher { + /** + * Creates a `StateBorrowingMachineEClassSearcher` from a compiled pattern. + * + * @param pattern The pattern to search for. + * @tparam NodeT The type of the nodes in the e-graph. + * @tparam EGraphT The type of the e-graph that the searcher searches in. + * @return A new `StateBorrowingMachineEClassSearcher` instance. + */ + def apply[ + NodeT, + EGraphT <: EGraph[NodeT] + ](pattern: CompiledPattern[NodeT, EGraphT]): StateBorrowingMachineEClassSearcher[NodeT, EGraphT] = { + StateBorrowingMachineEClassSearcher(pattern, SearcherContinuation.identityBuilder) + } +} diff --git a/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala new file mode 100644 index 00000000..57427663 --- /dev/null +++ b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala @@ -0,0 +1,226 @@ +package foresight.eqsat + +import org.junit.Test +import org.junit.Assert._ +import scala.collection.compat.immutable.ArraySeq + +class ENodePoolTest { + // Helpers to build sized sequences without needing concrete Slot/EClassCall values + private def slots(n: Int): Seq[Slot] = { + // Backed by a real Array[Slot] of length n; elements are null and that's fine for pooling + ArraySeq.unsafeWrapArray(new Array[Slot](n)) + } + private def calls(n: Int): Seq[EClassCall] = { + ArraySeq.unsafeWrapArray(new Array[EClassCall](n)) + } + + @Test + def createPool(): Unit = { + val p = ENode.newPool(8) + assertNotNull(p) + val d = ENode.defaultPool + assertNotNull(d) + assertNotSame("newPool returns a different instance than the thread-local default", p, d) + } + + @Test + def acquireReleaseReacquireReusesNodeObjectWhenShapesMatch(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "add", definitions = slots(2), uses = slots(1), args = calls(3)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val args1 = n1.unsafeArgsArray + + // Release to the pool + p.release(n1) + + // Re-acquire with the same shapes; node object and arrays should be reused + val n2 = p.acquire(nodeType = "add", definitions = slots(2), uses = slots(1), args = calls(3)) + + // The node instance itself should be reused + assertTrue("Node object is expected to be reused", (n1 eq n2)) + + // The backing arrays should be reused by identity when lengths match + assertTrue("Definitions array should be reused by identity", (defs1 eq n2.unsafeDefinitionsArray)) + assertTrue("Uses array should be reused by identity", (uses1 eq n2.unsafeUsesArray)) + assertTrue("Args array should be reused by identity", (args1 eq n2.unsafeArgsArray)) + + // Clean up + p.release(n2) + } + + @Test + def acquireWithDifferentLengthsDoesNotReuseArraysAcrossBuckets(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "mul", definitions = slots(1), uses = slots(1), args = calls(2)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val args1 = n1.unsafeArgsArray + p.release(n1) + + // Change shapes so that buckets don't match; identities should differ + val n2 = p.acquire(nodeType = "mul", definitions = slots(2), uses = slots(1), args = calls(1)) + assertFalse(defs1 eq n2.unsafeDefinitionsArray) + assertTrue(uses1 eq n2.unsafeUsesArray) // same length 1 -> should reuse + assertFalse(args1 eq n2.unsafeArgsArray) + + p.release(n2) + } + + @Test + def zeroLengthArraysUseSharedEmptyButStillReuseNode(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "leaf", definitions = slots(0), uses = slots(0), args = calls(0)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val args1 = n1.unsafeArgsArray + assertEquals(0, defs1.length) + assertEquals(0, uses1.length) + assertEquals(0, args1.length) + + p.release(n1) + + val n2 = p.acquire(nodeType = "leaf", definitions = slots(0), uses = slots(0), args = calls(0)) + // Node object should be reused even though arrays are zero-length + assertTrue(n1 eq n2) + + // Zero-length arrays are the canonical empty arrays (identity may or may not match across acquires), + // but lengths must be 0 and nothing should crash. + assertEquals(0, n2.unsafeDefinitionsArray.length) + assertEquals(0, n2.unsafeUsesArray.length) + assertEquals(0, n2.unsafeArgsArray.length) + + p.release(n2) + } + + @Test + def pooledNodeBehavesLikeRegularRenameAndEquality(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh() + val y = Slot.fresh() + + val pooled = p.acquire(nodeType = 0, definitions = Seq(x), uses = Seq.empty, args = Seq.empty) + // Rename on a pooled node should behave identically to a regular node + val renamed = pooled.rename(collections.SlotMap.from(x -> y)) + assert(renamed == ENode(0, Seq(y), Seq.empty, Seq.empty)) + + // Release only the pooled node; renamed is a fresh ENode not owned by the pool + p.release(pooled) + } + + @Test + def pooledNodeAsShapeCallRoundTrip(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh(); val y = Slot.fresh(); val z = Slot.fresh(); + val w = Slot.fresh(); val v = Slot.fresh() + val c = new EClassRef(42) + + val pooled: ENode[Int] = p.acquire( + nodeType = 0, + definitions = Seq(x, y), + uses = Seq(z), + args = Seq(EClassCall(c, collections.SlotMap.from(w -> v))) + ) + + val call@ShapeCall(shape, args) = pooled.asShapeCall + assert(shape == ENode(0, + Seq(Slot.numeric(0), Slot.numeric(1)), + Seq(Slot.numeric(2)), + Seq(EClassCall(c, collections.SlotMap.from(w -> Slot.numeric(3)))) + )) + assert(args(Slot.numeric(0)) == x) + assert(args(Slot.numeric(1)) == y) + assert(args(Slot.numeric(2)) == z) + assert(args(Slot.numeric(3)) == v) + assert(call.asNode == pooled) + + p.release(pooled) + } + + @Test + def interleavedAcquireReleaseReusesCorrectBuckets(): Unit = { + val p = ENode.newPool(64) + + // Acquire A (1,0,0) + val a = p.acquire(nodeType = "A", definitions = slots(1), uses = slots(0), args = calls(0)) + val aDefs = a.unsafeDefinitionsArray + // Acquire B (2,1,3) + val b = p.acquire(nodeType = "B", definitions = slots(2), uses = slots(1), args = calls(3)) + val bDefs = b.unsafeDefinitionsArray; val bUses = b.unsafeUsesArray; val bArgs = b.unsafeArgsArray + + // Release in reverse order + p.release(b) + p.release(a) + + // Reacquire C matching B's shapes + val c = p.acquire(nodeType = "C", definitions = slots(2), uses = slots(1), args = calls(3)) + assertTrue(bDefs eq c.unsafeDefinitionsArray) + assertTrue("Uses array reuses one of the previously released length-1 slot arrays", (bUses eq c.unsafeUsesArray) || (aDefs eq c.unsafeUsesArray)) + assertTrue(bArgs eq c.unsafeArgsArray) + + // Reacquire D matching A's shapes + val d = p.acquire(nodeType = "D", definitions = slots(1), uses = slots(0), args = calls(0)) + assertTrue("Definitions array reuses one of the previously released length-1 slot arrays", (aDefs eq d.unsafeDefinitionsArray) || (bUses eq d.unsafeDefinitionsArray)) + + p.release(c); p.release(d) + } + + @Test + def nodeTypeCanChangeAcrossReuseAndAccessorRemainsCorrect(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = 0, definitions = slots(1), uses = slots(0), args = calls(0)) + val sameObj1 = n1 + assertEquals(0, sameObj1.nodeType) + p.release(n1) + + // Reuse same object but with a different nodeType type (String) + val n2 = p.acquire(nodeType = "op", definitions = slots(1), uses = slots(0), args = calls(0)) + assertTrue("Object identity reused", (sameObj1 eq n2)) + assertEquals("op", n2.nodeType) + p.release(n2) + } + + @Test + def equalsAndHashRemainStableAfterReuse(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh(); val y = Slot.fresh(); val z = Slot.fresh() + + val n1 = p.acquire(nodeType = 1, definitions = Seq(x, y), uses = Seq(z), args = Seq.empty) + val expected = ENode(1, Seq(x, y), Seq(z), Seq.empty) + assert(n1 == expected) + val h1 = n1.hashCode + p.release(n1) + + val n2 = p.acquire(nodeType = 1, definitions = Seq(x, y), uses = Seq(z), args = Seq.empty) + assert(n2 == expected) + val h2 = n2.hashCode + + // Not guaranteed to be identical across implementations, but should be consistent for equal nodes + assertEquals(expected.hashCode, h2) + + p.release(n2) + } + + @Test + def poolIgnoresForeignNodesOnRelease(): Unit = { + val p = ENode.newPool(64) + + // Create a non-pooled node + val foreign = ENode("x", Seq.empty, Seq.empty, Seq.empty) + + // Releasing should be a no-op (and must not throw) + p.release(foreign) + + // Pool still hands out valid nodes afterwards + val n = p.acquire(nodeType = "y", definitions = slots(1), uses = slots(0), args = calls(0)) + assertNotNull(n) + p.release(n) + } +} diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index 5e18370e..480e8fed 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -2,7 +2,7 @@ package foresight.eqsat.commands import foresight.eqsat.collections.SlotSeq import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} import foresight.eqsat.immutable.EGraph import org.junit.Test @@ -186,7 +186,7 @@ class CommandScheduleBuilderTest { // Use addSimplifiedReal on an Atom (real call). val sym = baseBuilder.addSimplifiedReal( - MixedTree.Atom[Int, EClassCall](realCall), + realCall, g1 ) @@ -206,7 +206,7 @@ class CommandScheduleBuilderTest { val g0 = EGraph.empty[Int] // Build a trivial 1-node tree: Node(7, [], [], []) - val tree = MixedTree.Node[Int, EClassCall](7, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty[MixedTree[Int, EClassCall]]) + val tree = CallTree.Node[Int](7, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty[CallTree[Int]]) val sym = builder.addSimplifiedReal(tree, g0) // We only care that exactly one node gets scheduled and applying the schedule grows the graph by 1. @@ -240,14 +240,14 @@ class CommandScheduleBuilderTest { val realLeaf: EClassCall = g1.canonicalize(g1.classes.head) // child = Node(10, [], [], [ Atom(realLeaf) ]) - val child = MixedTree.Node[Int, EClassCall]( + val child = CallTree.Node[Int]( 10, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realLeaf)) + ArraySeq(realLeaf) ) // parent = Node(20, [], [], [ child ]) - val parent = MixedTree.Node[Int, EClassCall]( + val parent = CallTree.Node[Int]( 20, SlotSeq.empty, SlotSeq.empty, @@ -296,29 +296,29 @@ class CommandScheduleBuilderTest { } // Level 1 nodes (directly depend on real atoms) - val child1 = MixedTree.Node[Int, EClassCall]( + val child1 = CallTree.Node[Int]( 201, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realA)) + ArraySeq(realA) ) - val child2 = MixedTree.Node[Int, EClassCall]( + val child2 = CallTree.Node[Int]( 202, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realA), MixedTree.Atom[Int, EClassCall](realB)) + ArraySeq(realA, realB) ) val childSym1 = builder.addSimplifiedReal(child1, g1) val childSym2 = builder.addSimplifiedReal(child2, g1) // Level 2 nodes (depend on Level 1) - val parent1 = MixedTree.Node[Int, EClassCall]( + val parent1 = CallTree.Node[Int]( 301, SlotSeq.empty, SlotSeq.empty, - ArraySeq(child1, MixedTree.Atom[Int, EClassCall](realB)) + ArraySeq(child1, realB) ) - val parent2 = MixedTree.Node[Int, EClassCall]( + val parent2 = CallTree.Node[Int]( 302, SlotSeq.empty, SlotSeq.empty, @@ -328,7 +328,7 @@ class CommandScheduleBuilderTest { val parentSym2 = builder.addSimplifiedReal(parent2, g1) // Level 3 node (root depends on both Level 2 nodes) - val root = MixedTree.Node[Int, EClassCall]( + val root = CallTree.Node[Int]( 401, SlotSeq.empty, SlotSeq.empty, diff --git a/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala b/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala index 0caf7cca..7d436bd5 100644 --- a/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala @@ -63,8 +63,8 @@ class MutableMachineStateTest { val m = MutableMachineState[Any](root, effects) // Values for the variables - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] m.bindVar(mt1) m.bindVar(mt2) @@ -170,8 +170,8 @@ class MutableMachineStateTest { val m = MutableMachineState[Any](root, effects) // Bind two variables - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] m.bindVar(mt1) m.bindVar(mt2) @@ -269,7 +269,7 @@ class MutableMachineStateTest { val ms1 = pool.borrow(root1) // Mutate state to non-zero indices - val mt: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt: EClassCall = null.asInstanceOf[EClassCall] ms1.bindVar(mt) ms1.bindVar(mt) assertEquals(2, ms1.boundVarsCount) @@ -390,8 +390,8 @@ class MutableMachineStateTest { val ms = pool.borrow(root) // Populate state: two vars and one node with one slot use - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] ms.bindVar(mt1) ms.bindVar(mt2)