From d878703daa5469efffe5f3d45f2506429052a712 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:17:32 -0500 Subject: [PATCH 01/23] Optimize ENode memory management with an optional pooling mechanism --- .../main/scala/foresight/eqsat/ENode.scala | 219 ++++++++++++++++- .../scala/foresight/eqsat/ENodePoolTest.scala | 226 ++++++++++++++++++ 2 files changed, 432 insertions(+), 13 deletions(-) create mode 100644 foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index fec0fdfe..1cd5d1bf 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -18,18 +18,24 @@ import scala.collection.compat.immutable.ArraySeq * * Node types (`NodeT`) supply the operator and any non-structural payload. Slots and arguments are provided here. * - * @param nodeType Operator or symbol for this node. * @tparam NodeT The domain-specific node type. It defines operator identity and payload but not slots/children. */ final class ENode[+NodeT] private ( - val nodeType: NodeT, - private val _definitions: Array[Slot], - private val _uses: Array[Slot], - private val _args: Array[EClassCall] + private var _nodeType: Any, + private var _definitions: Array[Slot], + private var _uses: Array[Slot], + private var _args: Array[EClassCall] ) extends Node[NodeT, EClassCall] with ENodeSymbol[NodeT] { // Cached hash code to make hashing and equality fast (benign data race; like String.hash) private var _hash: Int = 0 + /** + * The operator or symbol for this node. + * + * @return The node type. + */ + def nodeType: NodeT = _nodeType.asInstanceOf[NodeT] + /** * Slots introduced by this node that are scoped locally and invisible to parents. 
These are * redundant by construction at the boundary of this node and exist to model binders such as @@ -60,6 +66,16 @@ final class ENode[+NodeT] private ( */ private[eqsat] def unsafeArgsArray: Array[EClassCall] = _args + /** + * Unsafe access to the internal array of definition slots. Do not modify. + */ + private[eqsat] def unsafeDefinitionsArray: Array[Slot] = _definitions + + /** + * Unsafe access to the internal array of use slots. Do not modify. + */ + private[eqsat] def unsafeUsesArray: Array[Slot] = _uses + /** * The total number of slots occurring in this node: definitions, uses, and children’s argument slots. * Includes slots that may be duplicated across these categories. @@ -407,17 +423,194 @@ final class ENode[+NodeT] private ( * Constructors and helpers for [[ENode]]. */ object ENode { - private[eqsat] def arraysEqual[T](a: Array[T], b: Array[T]): Boolean = { - if (a eq b) return true - if (a.length != b.length) return false - var i = 0 - while (i < a.length) { - if (a(i) != b(i)) return false - i += 1 + /** + * Pool for building and recycling `ENode`s with re-usable backing arrays. + * + * This pool both reuses `ENode` objects themselves and their backing arrays for `definitions`, + * `uses`, and `args` to cut allocations on hot paths. 
+ */ + final class Pool private[ENode] ( + private val perBucketCap: Int + ) { + // Buckets indexed by length -> stack of arrays + private val slotBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[Slot]]]() + private val callBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[EClassCall]]]() + // Free-list of reusable ENode objects + private val nodeFree = new java.util.ArrayDeque[ENode[Any]]() + private def borrowNode(): ENode[Any] = { + if (nodeFree.isEmpty) new ENode[Any](null.asInstanceOf[Any], emptySlotArray, emptySlotArray, emptyCallArray) + else nodeFree.removeFirst() + } + + private def prepareNode[NodeT](node: ENode[Any], nodeType: NodeT, defs: Array[Slot], uses: Array[Slot], args: Array[EClassCall]): ENode[NodeT] = { + node._nodeType = nodeType + node._definitions = defs + node._uses = uses + node._args = args + node._hash = 0 // reset cached hash + node.asInstanceOf[ENode[NodeT]] + } + + /** + * Acquires a slot array of the given length from this pool. + * @param len Length of the array. + * @return An array of the given length. + */ + def acquireSlotArray(len: Int): Array[Slot] = { + if (len == 0) return emptySlotArray + val q = slotBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[Slot]]()) + val arr = if (q.isEmpty) new Array[Slot](len) else q.removeFirst() + arr + } + + /** + * Acquires a call array of the given length from this pool. + * @param len Length of the array. + * @return An array of the given length. + */ + def acquireCallArray(len: Int): Array[EClassCall] = { + if (len == 0) return emptyCallArray + val q = callBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[EClassCall]]()) + val arr = if (q.isEmpty) new Array[EClassCall](len) else q.removeFirst() + arr + } + + /** + * Releases a slot array back to this pool. + * @param arr Array to release. 
+ */ + def releaseSlotArray(arr: Array[Slot]): Unit = { + if (arr eq null) return + val len = arr.length + if (len == 0) return + java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + val q = slotBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[Slot]]()) + if (q.size() < perBucketCap) q.addFirst(arr) + } + + /** + * Releases a call array back to this pool. + * @param arr Array to release. + */ + def releaseCallArray(arr: Array[EClassCall]): Unit = { + if (arr eq null) return + val len = arr.length + if (len == 0) return + java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + val q = callBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[EClassCall]]()) + if (q.size() < perBucketCap) q.addFirst(arr) + } + + private def copySlotsIntoPooledArray(slots: Seq[Slot]): Array[Slot] = slots match { + case slotSeq: SlotSeq => + val len = slotSeq.size + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + var i = 0 + while (i < len) { arr(i) = slotSeq.unsafeArray(i); i += 1 } + arr + } + case as: scala.collection.compat.immutable.ArraySeq[Slot] if as.unsafeArray.isInstanceOf[Array[Slot]] => + val src = as.unsafeArray.asInstanceOf[Array[Slot]] + val len = src.length + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + System.arraycopy(src, 0, arr, 0, len) + arr + } + case _ => + val len = slots.size + if (len == 0) emptySlotArray + else { + val arr = acquireSlotArray(len) + var i = 0 + val it = slots.iterator + while (i < len && it.hasNext) { arr(i) = it.next(); i += 1 } + arr + } + } + + private def copyCallsIntoPooledArray(calls: Seq[EClassCall]): Array[EClassCall] = calls match { + case as: scala.collection.compat.immutable.ArraySeq[EClassCall] if as.unsafeArray.isInstanceOf[Array[EClassCall]] => + val src = as.unsafeArray.asInstanceOf[Array[EClassCall]] + val len = src.length + if (len == 0) emptyCallArray + else { + val arr = acquireCallArray(len) + System.arraycopy(src, 0, arr, 0, 
len) + arr + } + case _ => + val len = calls.size + if (len == 0) emptyCallArray + else { + val arr = acquireCallArray(len) + var i = 0 + val it = calls.iterator + while (i < len && it.hasNext) { arr(i) = it.next(); i += 1 } + arr + } + } + + /** + * Builds an `ENode` whose backing arrays are owned by this pool and therefore recyclable via [[release]]. + */ + def acquire[NodeT](nodeType: NodeT, definitions: Seq[Slot], uses: Seq[Slot], args: Seq[EClassCall]): ENode[NodeT] = { + val defsArr = copySlotsIntoPooledArray(definitions) + val usesArr = copySlotsIntoPooledArray(uses) + val argsArr = copyCallsIntoPooledArray(args) + acquireUnsafe(nodeType, defsArr, usesArr, argsArr) + } + + /** + * Builds an `ENode` whose backing arrays are owned by this pool and therefore recyclable via [[release]]. + * + * The caller must ensure that the provided arrays are not used or retained after calling this method. + */ + private[eqsat] def acquireUnsafe[NodeT](nodeType: NodeT, definitions: Array[Slot], uses: Array[Slot], args: Array[EClassCall]): ENode[NodeT] = { + val nAny = borrowNode() + val n = prepareNode(nAny, nodeType, definitions, uses, args) + n + } + + /** + * Returns `node` to this pool. After calling this, the caller must not keep or use `node`. + */ + def release(node: ENode[_]): Unit = { + // Return backing arrays to this pool; no ownership check is performed, so any node's arrays are recycled + releaseSlotArray(node.unsafeDefinitionsArray) + releaseSlotArray(node.unsafeUsesArray) + releaseCallArray(node.unsafeArgsArray) + + val nAny = node.asInstanceOf[ENode[Any]] + // Clear fields to avoid retaining references and to mark as blank + nAny._nodeType = null + nAny._definitions = emptySlotArray + nAny._uses = emptySlotArray + nAny._args = emptyCallArray + nAny._hash = 0 + if (nodeFree.size() < perBucketCap) nodeFree.addFirst(nAny) + } + + /** Clears all buckets. 
*/ + def clear(): Unit = { + slotBuckets.clear() + callBuckets.clear() + nodeFree.clear() } - true } + /** Thread-local default pool for lightweight reuse without wiring one through call sites. */ + private val threadLocalDefaultPool = new ThreadLocal[Pool] { override def initialValue(): Pool = new Pool(64) } + + /** Creates a new pool with the given per-bucket capacity. */ + def newPool(perBucketCap: Int = 64): Pool = new Pool(perBucketCap) + + /** Gets the thread-local default pool. */ + def defaultPool: Pool = threadLocalDefaultPool.get() + private val emptySlotArray: Array[Slot] = Array.empty private val emptyCallArray: Array[EClassCall] = Array.empty diff --git a/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala new file mode 100644 index 00000000..81936398 --- /dev/null +++ b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala @@ -0,0 +1,226 @@ +package foresight.eqsat + +import org.junit.Test +import org.junit.Assert._ +import scala.collection.compat.immutable.ArraySeq + +class ENodePoolTest { + // Helpers to build sized sequences without needing concrete Slot/EClassCall values + private def slots(n: Int): Seq[Slot] = { + // Backed by a real Array[Slot] of length n; elements are null and that's fine for pooling + ArraySeq.unsafeWrapArray(new Array[Slot](n)) + } + private def calls(n: Int): Seq[EClassCall] = { + ArraySeq.unsafeWrapArray(new Array[EClassCall](n)) + } + + @Test + def createPool(): Unit = { + val p = ENode.newPool(8) + assertNotNull(p) + val d = ENode.defaultPool + assertNotNull(d) + assertNotSame("newPool returns a different instance than the thread-local default", p, d) + } + + @Test + def acquireReleaseReacquireReusesNodeObjectWhenShapesMatch(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "add", definitions = slots(2), uses = slots(1), args = calls(3)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val 
args1 = n1.unsafeArgsArray + + // Release to the pool + p.release(n1) + + // Re-acquire with the same shapes; node object and arrays should be reused + val n2 = p.acquire(nodeType = "add", definitions = slots(2), uses = slots(1), args = calls(3)) + + // The node instance itself should be reused + assertTrue("Node object is expected to be reused", (n1 eq n2)) + + // The backing arrays should be reused by identity when lengths match + assertTrue("Definitions array should be reused by identity", (defs1 eq n2.unsafeDefinitionsArray)) + assertTrue("Uses array should be reused by identity", (uses1 eq n2.unsafeUsesArray)) + assertTrue("Args array should be reused by identity", (args1 eq n2.unsafeArgsArray)) + + // Clean up + p.release(n2) + } + + @Test + def acquireWithDifferentLengthsDoesNotReuseArraysAcrossBuckets(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "mul", definitions = slots(1), uses = slots(1), args = calls(2)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val args1 = n1.unsafeArgsArray + p.release(n1) + + // Change shapes so that buckets don't match; identities should differ + val n2 = p.acquire(nodeType = "mul", definitions = slots(2), uses = slots(1), args = calls(1)) + assertFalse(defs1 eq n2.unsafeDefinitionsArray) + assertTrue(uses1 eq n2.unsafeUsesArray) // same length 1 -> should reuse + assertFalse(args1 eq n2.unsafeArgsArray) + + p.release(n2) + } + + @Test + def zeroLengthArraysUseSharedEmptyButStillReuseNode(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = "leaf", definitions = slots(0), uses = slots(0), args = calls(0)) + val defs1 = n1.unsafeDefinitionsArray + val uses1 = n1.unsafeUsesArray + val args1 = n1.unsafeArgsArray + assertEquals(0, defs1.length) + assertEquals(0, uses1.length) + assertEquals(0, args1.length) + + p.release(n1) + + val n2 = p.acquire(nodeType = "leaf", definitions = slots(0), uses = slots(0), args = calls(0)) + // Node object should be 
reused even though arrays are zero-length + assertTrue(n1 eq n2) + + // Zero-length arrays are the canonical empty arrays (identity may or may not match across acquires), + // but lengths must be 0 and nothing should crash. + assertEquals(0, n2.unsafeDefinitionsArray.length) + assertEquals(0, n2.unsafeUsesArray.length) + assertEquals(0, n2.unsafeArgsArray.length) + + p.release(n2) + } + + @Test + def pooledNodeBehavesLikeRegularRenameAndEquality(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh() + val y = Slot.fresh() + + val pooled = p.acquire(nodeType = 0, definitions = Seq(x), uses = Seq.empty, args = Seq.empty) + // Rename on a pooled node should behave identically to a regular node + val renamed = pooled.rename(collections.SlotMap.from(x -> y)) + assert(renamed == ENode(0, Seq(y), Seq.empty, Seq.empty)) + + // Release only the pooled node; renamed is a fresh ENode not owned by the pool + p.release(pooled) + } + + @Test + def pooledNodeAsShapeCallRoundTrip(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh(); val y = Slot.fresh(); val z = Slot.fresh(); + val w = Slot.fresh(); val v = Slot.fresh() + val c = new EClassRef(42) + + val pooled = p.acquire( + nodeType = 0, + definitions = Seq(x, y), + uses = Seq(z), + args = Seq(EClassCall(c, collections.SlotMap.from(w -> v))) + ) + + val call@ShapeCall(shape, args) = pooled.asShapeCall + assert(shape == ENode(0, + Seq(Slot.numeric(0), Slot.numeric(1)), + Seq(Slot.numeric(2)), + Seq(EClassCall(c, collections.SlotMap.from(w -> Slot.numeric(3)))) + )) + assert(args(Slot.numeric(0)) == x) + assert(args(Slot.numeric(1)) == y) + assert(args(Slot.numeric(2)) == z) + assert(args(Slot.numeric(3)) == v) + assert(call.asNode == pooled) + + p.release(pooled) + } + + @Test + def interleavedAcquireReleaseReusesCorrectBuckets(): Unit = { + val p = ENode.newPool(64) + + // Acquire A (1,0,0) + val a = p.acquire(nodeType = "A", definitions = slots(1), uses = slots(0), args = calls(0)) + val aDefs = 
a.unsafeDefinitionsArray + // Acquire B (2,1,3) + val b = p.acquire(nodeType = "B", definitions = slots(2), uses = slots(1), args = calls(3)) + val bDefs = b.unsafeDefinitionsArray; val bUses = b.unsafeUsesArray; val bArgs = b.unsafeArgsArray + + // Release in reverse order + p.release(b) + p.release(a) + + // Reacquire C matching B's shapes + val c = p.acquire(nodeType = "C", definitions = slots(2), uses = slots(1), args = calls(3)) + assertTrue(bDefs eq c.unsafeDefinitionsArray) + assertTrue("Uses array reuses one of the previously released length-1 slot arrays", (bUses eq c.unsafeUsesArray) || (aDefs eq c.unsafeUsesArray)) + assertTrue(bArgs eq c.unsafeArgsArray) + + // Reacquire D matching A's shapes + val d = p.acquire(nodeType = "D", definitions = slots(1), uses = slots(0), args = calls(0)) + assertTrue("Definitions array reuses one of the previously released length-1 slot arrays", (aDefs eq d.unsafeDefinitionsArray) || (bUses eq d.unsafeDefinitionsArray)) + + p.release(c); p.release(d) + } + + @Test + def nodeTypeCanChangeAcrossReuseAndAccessorRemainsCorrect(): Unit = { + val p = ENode.newPool(64) + + val n1 = p.acquire(nodeType = 0, definitions = slots(1), uses = slots(0), args = calls(0)) + val sameObj1 = n1 + assertEquals(0, sameObj1.nodeType) + p.release(n1) + + // Reuse same object but with a different nodeType type (String) + val n2 = p.acquire(nodeType = "op", definitions = slots(1), uses = slots(0), args = calls(0)) + assertTrue("Object identity reused", (sameObj1 eq n2)) + assertEquals("op", n2.nodeType) + p.release(n2) + } + + @Test + def equalsAndHashRemainStableAfterReuse(): Unit = { + val p = ENode.newPool(64) + + val x = Slot.fresh(); val y = Slot.fresh(); val z = Slot.fresh() + + val n1 = p.acquire(nodeType = 1, definitions = Seq(x, y), uses = Seq(z), args = Seq.empty) + val expected = ENode(1, Seq(x, y), Seq(z), Seq.empty) + assert(n1 == expected) + val h1 = n1.hashCode + p.release(n1) + + val n2 = p.acquire(nodeType = 1, definitions = Seq(x, 
y), uses = Seq(z), args = Seq.empty) + assert(n2 == expected) + val h2 = n2.hashCode + + // Not guaranteed to be identical across implementations, but should be consistent for equal nodes + assertEquals(expected.hashCode, h2) + + p.release(n2) + } + + @Test + def poolIgnoresForeignNodesOnRelease(): Unit = { + val p = ENode.newPool(64) + + // Create a non-pooled node + val foreign = ENode("x", Seq.empty, Seq.empty, Seq.empty) + + // Releasing should be a no-op (and must not throw) + p.release(foreign) + + // Pool still hands out valid nodes afterwards + val n = p.acquire(nodeType = "y", definitions = slots(1), uses = slots(0), args = calls(0)) + assertNotNull(n) + p.release(n) + } +} From 4435469f8035f110557380ed50c954491fd1ccf6 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:33:48 -0500 Subject: [PATCH 02/23] Refactor ENode array allocation to use dedicated delegates --- .../main/scala/foresight/eqsat/ENode.scala | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index 1cd5d1bf..3f0c24e2 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -423,6 +423,14 @@ final class ENode[+NodeT] private ( * Constructors and helpers for [[ENode]]. */ object ENode { + private val newSlotDequeDelegate = new java.util.function.Function[Int, java.util.ArrayDeque[Array[Slot]]] { + override def apply(t: Int): java.util.ArrayDeque[Array[Slot]] = new java.util.ArrayDeque[Array[Slot]] + } + + private val newCallDequeDelegate = new java.util.function.Function[Int, java.util.ArrayDeque[Array[EClassCall]]] { + override def apply(t: Int): java.util.ArrayDeque[Array[EClassCall]] = new java.util.ArrayDeque[Array[EClassCall]] + } + /** * Pool for building and recycling `ENode`s with re-usable backing arrays. 
* @@ -436,7 +444,7 @@ object ENode { private val slotBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[Slot]]]() private val callBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[EClassCall]]]() // Free-list of reusable ENode objects - private val nodeFree = new java.util.ArrayDeque[ENode[Any]]() + private val nodeFree = new java.util.ArrayDeque[ENode[Any]]() private def borrowNode(): ENode[Any] = { if (nodeFree.isEmpty) new ENode[Any](null.asInstanceOf[Any], emptySlotArray, emptySlotArray, emptyCallArray) else nodeFree.removeFirst() @@ -458,7 +466,7 @@ object ENode { */ def acquireSlotArray(len: Int): Array[Slot] = { if (len == 0) return emptySlotArray - val q = slotBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[Slot]]()) + val q = slotBuckets.computeIfAbsent(len, newSlotDequeDelegate) val arr = if (q.isEmpty) new Array[Slot](len) else q.removeFirst() arr } @@ -470,7 +478,7 @@ object ENode { */ def acquireCallArray(len: Int): Array[EClassCall] = { if (len == 0) return emptyCallArray - val q = callBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[EClassCall]]()) + val q = callBuckets.computeIfAbsent(len, newCallDequeDelegate) val arr = if (q.isEmpty) new Array[EClassCall](len) else q.removeFirst() arr } @@ -484,7 +492,7 @@ object ENode { val len = arr.length if (len == 0) return java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) - val q = slotBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[Slot]]()) + val q = slotBuckets.computeIfAbsent(len, newSlotDequeDelegate) if (q.size() < perBucketCap) q.addFirst(arr) } @@ -497,7 +505,7 @@ object ENode { val len = arr.length if (len == 0) return java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) - val q = callBuckets.computeIfAbsent(len, _ => new java.util.ArrayDeque[Array[EClassCall]]()) + val q = callBuckets.computeIfAbsent(len, newCallDequeDelegate) if (q.size() < perBucketCap) q.addFirst(arr) } @@ -604,7 +612,7 @@ object 
ENode { /** Thread-local default pool for lightweight reuse without wiring one through call sites. */ private val threadLocalDefaultPool = new ThreadLocal[Pool] { override def initialValue(): Pool = new Pool(64) } - + /** Creates a new pool with the given per-bucket capacity. */ def newPool(perBucketCap: Int = 64): Pool = new Pool(perBucketCap) From 55b89b1723e7f632992ee4aa0ca7e3f46c9987f5 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:33:56 -0500 Subject: [PATCH 03/23] Specify type for pooled ENode acquisition in ENodePoolTest --- foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala index 81936398..57427663 100644 --- a/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/ENodePoolTest.scala @@ -120,7 +120,7 @@ class ENodePoolTest { val w = Slot.fresh(); val v = Slot.fresh() val c = new EClassRef(42) - val pooled = p.acquire( + val pooled: ENode[Int] = p.acquire( nodeType = 0, definitions = Seq(x, y), uses = Seq(z), From 41883a1191d5ba1d1e34238a2af02e3f26b1df39 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:34:08 -0500 Subject: [PATCH 04/23] Enhance CommandScheduleBuilder and PatternApplier to utilize ENode pooling for improved memory management --- .../eqsat/commands/CommandScheduleBuilder.scala | 15 +++++++++------ .../eqsat/rewriting/patterns/PatternApplier.scala | 11 ++++++----- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index a8c9fbf2..b4798a7b 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ 
b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -84,12 +84,13 @@ trait CommandScheduleBuilder[NodeT] { private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT]): EClassSymbol = { val maxBatch = new IntRef(0) - addSimplifiedReal(tree, egraph, maxBatch) + addSimplifiedReal(tree, egraph, maxBatch, ENode.defaultPool) } private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT], - maxBatch: IntRef): EClassSymbol = { + maxBatch: IntRef, + pool: ENode.Pool): EClassSymbol = { tree match { case MixedTree.Node(t, defs, uses, args) => // Local accumulator for children of this node. @@ -97,9 +98,9 @@ trait CommandScheduleBuilder[NodeT] { val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, - (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb) + (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb, pool) ) - val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph) + val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, pool) // Propagate maximum required batch up to the caller's accumulator. if (childMax.elem > maxBatch.elem) maxBatch.elem = childMax.elem sym @@ -115,7 +116,8 @@ trait CommandScheduleBuilder[NodeT] { uses: SlotSeq, args: Array[EClassSymbol], maxBatch: IntRef, - egraph: EGraph[NodeT]): EClassSymbol = { + egraph: EGraph[NodeT], + pool: ENode.Pool): EClassSymbol = { // Check if all children are already in the graph. 
val argCalls = CommandScheduleBuilder.resolveAllOrNull(args) @@ -126,7 +128,7 @@ trait CommandScheduleBuilder[NodeT] { assert(maxBatch.elem == 0) } - val candidateNode = ENode.unsafeWrapArrays(nodeType, definitions, uses, argCalls) + val candidateNode = pool.acquire(nodeType, definitions, uses, argCalls) egraph.findOrNull(candidateNode) match { case null => // Node does not exist in the graph but its children do exist in the graph. @@ -135,6 +137,7 @@ trait CommandScheduleBuilder[NodeT] { case existingCall => // Node already exists in the graph; reuse its class. + pool.release(candidateNode) EClassSymbol.real(existingCall) } } else { diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 91aedae9..fc5055d6 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -4,7 +4,7 @@ import foresight.eqsat.collections.SlotSeq import foresight.eqsat.commands.{CommandScheduleBuilder, IntRef} import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} -import foresight.eqsat.{EClassSymbol, MixedTree, Slot} +import foresight.eqsat.{EClassSymbol, ENode, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq @@ -67,7 +67,8 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], egraph: EGraphT, - builder: CommandScheduleBuilder[NodeT]) { + builder: CommandScheduleBuilder[NodeT], + pool: ENode.Pool) { def instantiate(pattern: MixedTree[NodeT, Pattern.Var], maxBatch: IntRef): EClassSymbol = { pattern match { case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) @@ -83,7 +84,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } 
val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - new SimplifiedAddCommandInstantiator(newMatch, egraph, builder).addSimplifiedNode(t, defSlots, uses, args, maxBatch) + new SimplifiedAddCommandInstantiator(newMatch, egraph, builder, pool).addSimplifiedNode(t, defSlots, uses, args, maxBatch) } } @@ -95,7 +96,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT val argMaxBatch = new IntRef(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) - val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph) + val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph, pool) if (argMaxBatch.elem > maxBatch.elem) { maxBatch.elem = argMaxBatch.elem } @@ -108,6 +109,6 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): EClassSymbol = { - new SimplifiedAddCommandInstantiator(m, egraph, builder).instantiate(pattern, new IntRef(0)) + new SimplifiedAddCommandInstantiator(m, egraph, builder, ENode.defaultPool).instantiate(pattern, new IntRef(0)) } } From c405ceeb0433da1fa57d57e33081df6be4e28bbc Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:41:10 -0500 Subject: [PATCH 05/23] Add methods to acquire and fill slot and call arrays from the pool --- .../src/main/scala/foresight/eqsat/ENode.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index 3f0c24e2..ec3887e7 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -471,6 +471,15 @@ object ENode { arr } + /** + * Acquires a slot array filled with the given slots from this pool. 
+ * @param slots Slots to fill into the array. + * @return An array containing the given slots. + */ + def acquireAndFillSlotArray(slots: Seq[Slot]): Array[Slot] = { + copySlotsIntoPooledArray(slots) + } + /** * Acquires a call array of the given length from this pool. * @param len Length of the array. @@ -483,6 +492,15 @@ object ENode { arr } + /** + * Acquires a call array filled with the given calls from this pool. + * @param calls Calls to fill into the array. + * @return An array containing the given calls. + */ + def acquireAndFillCallArray(calls: Seq[EClassCall]): Array[EClassCall] = { + copyCallsIntoPooledArray(calls) + } + /** * Releases a slot array back to this pool. * @param arr Array to release. From 824cd2570f426c21a090b32698da1ba420b48f95 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:41:16 -0500 Subject: [PATCH 06/23] Refactor candidate node acquisition to use acquireUnsafe and fill slot arrays --- .../foresight/eqsat/commands/CommandScheduleBuilder.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index b4798a7b..6081e6ce 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -128,7 +128,12 @@ trait CommandScheduleBuilder[NodeT] { assert(maxBatch.elem == 0) } - val candidateNode = pool.acquire(nodeType, definitions, uses, argCalls) + val candidateNode = pool.acquireUnsafe( + nodeType, + pool.acquireAndFillSlotArray(definitions), + pool.acquireAndFillSlotArray(uses), + argCalls) + egraph.findOrNull(candidateNode) match { case null => // Node does not exist in the graph but its children do exist in the graph. 
From 1e975923c4a3716a844a9bb9df0d1fc9a29f2e03 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:47:41 -0500 Subject: [PATCH 07/23] Optimize memory management by replacing HashMap with fixed-size arrays for slot and call buckets in ENode pool --- .../main/scala/foresight/eqsat/ENode.scala | 64 ++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index ec3887e7..5cedec57 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -438,11 +438,35 @@ object ENode { * `uses`, and `args` to cut allocations on hot paths. */ final class Pool private[ENode] ( - private val perBucketCap: Int + private val perBucketCap: Int, + private val maxBucketLen: Int = 16 ) { // Buckets indexed by length -> stack of arrays - private val slotBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[Slot]]]() - private val callBuckets = new java.util.HashMap[Int, java.util.ArrayDeque[Array[EClassCall]]]() + private val slotBuckets: Array[java.util.ArrayDeque[Array[Slot]]] = + new Array[java.util.ArrayDeque[Array[Slot]]](maxBucketLen + 1) + private val callBuckets: Array[java.util.ArrayDeque[Array[EClassCall]]] = + new Array[java.util.ArrayDeque[Array[EClassCall]]](maxBucketLen + 1) + + // Eagerly initialize deques for all bucket lengths + { + var i = 0 + while (i <= maxBucketLen) { + slotBuckets(i) = new java.util.ArrayDeque[Array[Slot]]() + callBuckets(i) = new java.util.ArrayDeque[Array[EClassCall]]() + i += 1 + } + } + + @inline private def slotDeque(len: Int): java.util.ArrayDeque[Array[Slot]] = { + if (len < 0 || len > maxBucketLen) null + else slotBuckets(len) + } + + @inline private def callDeque(len: Int): java.util.ArrayDeque[Array[EClassCall]] = { + if (len < 0 || len > maxBucketLen) null + else callBuckets(len) + } + // Free-list of reusable ENode 
objects private val nodeFree = new java.util.ArrayDeque[ENode[Any]]() private def borrowNode(): ENode[Any] = { @@ -466,9 +490,9 @@ object ENode { */ def acquireSlotArray(len: Int): Array[Slot] = { if (len == 0) return emptySlotArray - val q = slotBuckets.computeIfAbsent(len, newSlotDequeDelegate) - val arr = if (q.isEmpty) new Array[Slot](len) else q.removeFirst() - arr + val q = slotDeque(len) + if (q eq null) return new Array[Slot](len) + if (q.isEmpty) new Array[Slot](len) else q.removeFirst() } /** @@ -487,9 +511,9 @@ object ENode { */ def acquireCallArray(len: Int): Array[EClassCall] = { if (len == 0) return emptyCallArray - val q = callBuckets.computeIfAbsent(len, newCallDequeDelegate) - val arr = if (q.isEmpty) new Array[EClassCall](len) else q.removeFirst() - arr + val q = callDeque(len) + if (q eq null) return new Array[EClassCall](len) + if (q.isEmpty) new Array[EClassCall](len) else q.removeFirst() } /** @@ -510,8 +534,8 @@ object ENode { val len = arr.length if (len == 0) return java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) - val q = slotBuckets.computeIfAbsent(len, newSlotDequeDelegate) - if (q.size() < perBucketCap) q.addFirst(arr) + val q = slotDeque(len) + if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) } /** @@ -523,8 +547,8 @@ object ENode { val len = arr.length if (len == 0) return java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) - val q = callBuckets.computeIfAbsent(len, newCallDequeDelegate) - if (q.size() < perBucketCap) q.addFirst(arr) + val q = callDeque(len) + if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) } private def copySlotsIntoPooledArray(slots: Seq[Slot]): Array[Slot] = slots match { @@ -622,8 +646,18 @@ object ENode { /** Clears all buckets. 
*/ def clear(): Unit = { - slotBuckets.clear() - callBuckets.clear() + var i = 0 + while (i < slotBuckets.length) { + val q = slotBuckets(i) + if (q != null) q.clear() + i += 1 + } + i = 0 + while (i < callBuckets.length) { + val q = callBuckets(i) + if (q != null) q.clear() + i += 1 + } nodeFree.clear() } } From 98bc09388a41b9295dc72d50511d2c0bc0610355 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:52:40 -0500 Subject: [PATCH 08/23] Update symbolArrayFrom to utilize ENode pool for memory-efficient array allocation --- .../eqsat/commands/CommandScheduleBuilder.scala | 9 ++++++--- .../eqsat/rewriting/patterns/PatternApplier.scala | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index 6081e6ce..f93cadd4 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -98,6 +98,7 @@ trait CommandScheduleBuilder[NodeT] { val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, + pool, (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb, pool) ) val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, pool) @@ -165,14 +166,14 @@ object CommandScheduleBuilder { */ def newConcurrentBuilder[NodeT]: CommandScheduleBuilder[NodeT] = new ConcurrentCommandScheduleBuilder[NodeT]() - private[eqsat] def symbolArrayFrom[A](values: ArraySeq[A], maxBatch: IntRef, valueToSymbol: (A, IntRef) => EClassSymbol): Array[EClassSymbol] = { + private[eqsat] def symbolArrayFrom[A](values: ArraySeq[A], maxBatch: IntRef, pool: ENode.Pool, valueToSymbol: (A, IntRef) => EClassSymbol): Array[EClassSymbol] = { // Try to avoid allocating an array of EClassSymbol if all entries are EClassCall. 
// The common case is that all children are already in the e-graph, and we will // want to construct an ENode with an Array[EClassCall]. // If we find any entry that is not an EClassCall, we fall back to allocating // an Array[EClassSymbol] and copying the prefix of calls. val n = values.length - val calls = new Array[EClassCall](n) + val calls = pool.acquireCallArray(n) var i = 0 while (i < n) { valueToSymbol(values(i), maxBatch) match { @@ -183,8 +184,10 @@ object CommandScheduleBuilder { val syms = new Array[EClassSymbol](n) var j = 0 while (j < i) { - syms(j) = calls(j); j += 1 + syms(j) = calls(j) + j += 1 } + pool.releaseCallArray(calls) syms(i) = other j = i + 1 while (j < n) { diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index fc5055d6..5176d6f5 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -94,7 +94,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT args: ArraySeq[MixedTree[NodeT, Pattern.Var]], maxBatch: IntRef): EClassSymbol = { val argMaxBatch = new IntRef(0) - val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, instantiate) + val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, pool, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph, pool) if (argMaxBatch.elem > maxBatch.elem) { From 5f458a6a42885e5bc7fd44e115ac71995e78d85d Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 10:59:47 -0500 Subject: [PATCH 09/23] Refactor CommandScheduleBuilder and PatternApplier to utilize IntRef pooling for improved memory management --- .../commands/CommandScheduleBuilder.scala | 19 +++++---- 
.../foresight/eqsat/commands/IntRef.scala | 42 +++++++++++++++++++ .../rewriting/patterns/PatternApplier.scala | 19 ++++++--- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index f93cadd4..66ff33a7 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -83,27 +83,32 @@ trait CommandScheduleBuilder[NodeT] { private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT]): EClassSymbol = { - val maxBatch = new IntRef(0) - addSimplifiedReal(tree, egraph, maxBatch, ENode.defaultPool) + val refPool = IntRef.defaultPool + val maxBatch = refPool.acquire(0) + val result = addSimplifiedReal(tree, egraph, maxBatch, ENode.defaultPool, refPool) + refPool.release(maxBatch) + result } private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], egraph: EGraph[NodeT], maxBatch: IntRef, - pool: ENode.Pool): EClassSymbol = { + nodePool: ENode.Pool, + refPool: IntRef.Pool): EClassSymbol = { tree match { case MixedTree.Node(t, defs, uses, args) => // Local accumulator for children of this node. - val childMax = new IntRef(0) + val childMax = refPool.acquire(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, - pool, - (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb, pool) + nodePool, + (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb, nodePool, refPool) ) - val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, pool) + val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, nodePool) // Propagate maximum required batch up to the caller's accumulator. 
if (childMax.elem > maxBatch.elem) maxBatch.elem = childMax.elem + refPool.release(childMax) sym case MixedTree.Atom(call) => diff --git a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala index f0992cd6..3cde8dcc 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala @@ -1,3 +1,45 @@ package foresight.eqsat.commands private[eqsat] final class IntRef(var elem: Int) + +object IntRef { + /** + * Pool of reusable [[IntRef]] instances. + * + * Usage: + * val r = IntRef.acquire() // from the default thread-local pool + * r.elem = 42 + * IntRef.release(r) // return to pool + */ + final class Pool { + // LIFO stack to maximize cache locality + private val free = new java.util.ArrayDeque[IntRef]() + + /** Acquire an IntRef, initializing its value. */ + @inline def acquire(initial: Int): IntRef = { + val ref = free.pollFirst() + if (ref eq null) new IntRef(initial) + else { ref.elem = initial; ref } + } + + /** Return an IntRef to the pool for reuse. */ + @inline def release(ref: IntRef): Unit = { + // no double-free tracking for performance; callers ensure discipline + free.addFirst(ref) + } + + /** Number of currently stored reusable objects. */ + @inline def size: Int = free.size() + + /** Drop all cached objects. */ + def clear(): Unit = free.clear() + } + + // Default thread-local pool. + private val threadLocal: ThreadLocal[Pool] = new ThreadLocal[Pool] { + override def initialValue(): Pool = new Pool + } + + /** Access the default thread-local pool for the current thread. 
*/ + @inline def defaultPool: Pool = threadLocal.get() +} \ No newline at end of file diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 5176d6f5..d610be99 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -68,7 +68,8 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT], - pool: ENode.Pool) { + nodePool: ENode.Pool, + refPool: IntRef.Pool) { def instantiate(pattern: MixedTree[NodeT, Pattern.Var], maxBatch: IntRef): EClassSymbol = { pattern match { case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) @@ -84,7 +85,8 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - new SimplifiedAddCommandInstantiator(newMatch, egraph, builder, pool).addSimplifiedNode(t, defSlots, uses, args, maxBatch) + new SimplifiedAddCommandInstantiator(newMatch, egraph, builder, nodePool, refPool) + .addSimplifiedNode(t, defSlots, uses, args, maxBatch) } } @@ -93,13 +95,14 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT uses: SlotSeq, args: ArraySeq[MixedTree[NodeT, Pattern.Var]], maxBatch: IntRef): EClassSymbol = { - val argMaxBatch = new IntRef(0) - val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, pool, instantiate) + val argMaxBatch = refPool.acquire(0) + val argSymbols = CommandScheduleBuilder.symbolArrayFrom(args, argMaxBatch, nodePool, instantiate) val useSymbols = uses.map(m.apply: Slot => Slot) - val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, 
argMaxBatch, egraph, pool) + val result = builder.addSimplifiedNode(nodeType, definitions, useSymbols, argSymbols, argMaxBatch, egraph, nodePool) if (argMaxBatch.elem > maxBatch.elem) { maxBatch.elem = argMaxBatch.elem } + refPool.release(argMaxBatch) result } } @@ -109,6 +112,10 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): EClassSymbol = { - new SimplifiedAddCommandInstantiator(m, egraph, builder, ENode.defaultPool).instantiate(pattern, new IntRef(0)) + val refPool = IntRef.defaultPool + val ref = refPool.acquire(0) + val result = new SimplifiedAddCommandInstantiator(m, egraph, builder, ENode.defaultPool, refPool).instantiate(pattern, ref) + refPool.release(ref) + result } } From 7b096da1b5cb577663c1fbb188b4ed737f9f0364 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 11:07:19 -0500 Subject: [PATCH 10/23] Limit IntRef pool size to 64 for improved memory management and cache locality --- .../src/main/scala/foresight/eqsat/commands/IntRef.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala index 3cde8dcc..8f5dda17 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/IntRef.scala @@ -12,8 +12,9 @@ object IntRef { * IntRef.release(r) // return to pool */ final class Pool { + private final val maxSize = 64 // LIFO stack to maximize cache locality - private val free = new java.util.ArrayDeque[IntRef]() + private val free = new java.util.ArrayDeque[IntRef](maxSize) /** Acquire an IntRef, initializing its value. */ @inline def acquire(initial: Int): IntRef = { @@ -24,6 +25,7 @@ object IntRef { /** Return an IntRef to the pool for reuse. 
*/ @inline def release(ref: IntRef): Unit = { + if (free.size() >= maxSize) return // no double-free tracking for performance; callers ensure discipline free.addFirst(ref) } From 10caa18d50b2bb3549a89e0178f27939fc1f6ee5 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 11:15:29 -0500 Subject: [PATCH 11/23] Implement object pooling for SimplifiedAddCommandInstantiator to enhance memory management and reduce garbage collection overhead --- .../rewriting/patterns/PatternApplier.scala | 77 +++++++++++++++++-- 1 file changed, 69 insertions(+), 8 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index d610be99..8dec1547 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -7,6 +7,8 @@ import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} import foresight.eqsat.{EClassSymbol, ENode, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq +import java.util.ArrayDeque +import scala.compiletime.uninitialized /** * An applier that applies a pattern match to an e-graph. @@ -65,11 +67,35 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } - private final class SimplifiedAddCommandInstantiator(m: PatternMatch[NodeT], - egraph: EGraphT, - builder: CommandScheduleBuilder[NodeT], - nodePool: ENode.Pool, - refPool: IntRef.Pool) { + private final class SimplifiedAddCommandInstantiator { + // Mutable state to allow pooling; initialized via init() before use. 
+ private var m: PatternMatch[NodeT] = _ + private var egraph: EGraphT = _ + private var builder: CommandScheduleBuilder[NodeT] = _ + private var nodePool: ENode.Pool = _ + private var refPool: IntRef.Pool = _ + + def init(m0: PatternMatch[NodeT], + egraph0: EGraphT, + builder0: CommandScheduleBuilder[NodeT], + nodePool0: ENode.Pool, + refPool0: IntRef.Pool): Unit = { + this.m = m0 + this.egraph = egraph0 + this.builder = builder0 + this.nodePool = nodePool0 + this.refPool = refPool0 + } + + def clear(): Unit = { + // release references to help GC when pooled instances sit around + this.m = null + this.egraph = null.asInstanceOf[EGraphT] + this.builder = null + this.nodePool = null + this.refPool = null + } + def instantiate(pattern: MixedTree[NodeT, Pattern.Var], maxBatch: IntRef): EClassSymbol = { pattern match { case MixedTree.Atom(p) => builder.addSimplifiedReal(m(p), egraph) @@ -85,8 +111,16 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) - new SimplifiedAddCommandInstantiator(newMatch, egraph, builder, nodePool, refPool) - .addSimplifiedNode(t, defSlots, uses, args, maxBatch) + + // Acquire a nested instantiator. + val nested = SimplifiedAddCommandInstantiator.acquire() + nested.init(newMatch, egraph, builder, nodePool, refPool) + try { + nested.addSimplifiedNode(t, defSlots, uses, args, maxBatch) + } finally { + nested.clear() + SimplifiedAddCommandInstantiator.release(nested) + } } } @@ -107,6 +141,25 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } } + private object SimplifiedAddCommandInstantiator { + // Per-thread pool to avoid contention; LIFO to improve cache locality. 
+ private val local: ThreadLocal[ArrayDeque[SimplifiedAddCommandInstantiator]] = + new ThreadLocal[ArrayDeque[SimplifiedAddCommandInstantiator]]() { + override def initialValue(): ArrayDeque[SimplifiedAddCommandInstantiator] = + new ArrayDeque[SimplifiedAddCommandInstantiator]() + } + + def acquire(): SimplifiedAddCommandInstantiator = { + val dq = local.get() + val inst = dq.pollFirst() + if (inst != null) inst else new SimplifiedAddCommandInstantiator + } + + def release(inst: SimplifiedAddCommandInstantiator): Unit = { + local.get().offerFirst(inst) + } + } + private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], m: PatternMatch[NodeT], egraph: EGraphT, @@ -114,7 +167,15 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT val refPool = IntRef.defaultPool val ref = refPool.acquire(0) - val result = new SimplifiedAddCommandInstantiator(m, egraph, builder, ENode.defaultPool, refPool).instantiate(pattern, ref) + val inst = SimplifiedAddCommandInstantiator.acquire() + inst.init(m, egraph, builder, ENode.defaultPool, refPool) + val result = + try { + inst.instantiate(pattern, ref) + } finally { + inst.clear() + SimplifiedAddCommandInstantiator.release(inst) + } refPool.release(ref) result } From 4fce0834d623ff3a54a2ac9495b1d209239952af Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 11:19:36 -0500 Subject: [PATCH 12/23] Remove unused import in PatternApplier.scala to clean up code --- .../foresight/eqsat/rewriting/patterns/PatternApplier.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 8dec1547..19aef04a 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -8,7 +8,6 @@ import 
foresight.eqsat.{EClassSymbol, ENode, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq import java.util.ArrayDeque -import scala.compiletime.uninitialized /** * An applier that applies a pattern match to an e-graph. From a89c67b7a12791a09cc11c03ee929104ea57c745 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 11:25:43 -0500 Subject: [PATCH 13/23] Refactor MutableMachineState to replace MixedTree with EClassCall for bound variables --- .../eqsat/rewriting/patterns/Instruction.scala | 2 +- .../rewriting/patterns/MutableMachineState.scala | 16 +++++++++++----- .../patterns/MutableMachineStateTest.scala | 14 +++++++------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala index 3e2660c7..f8ab9f0b 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/Instruction.scala @@ -400,7 +400,7 @@ object Instruction { ) override def execute(ctx: Instruction.Execution[NodeT, EGraphT]): Boolean = { - val value = MixedTree.Atom[NodeT, EClassCall](ctx.machine.registerAt(register)) + val value = ctx.machine.registerAt(register) ctx.machine.bindVar(value) ctx.continue() } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala index 59e27379..e63e2c8c 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala @@ -11,7 +11,7 @@ import scala.collection.compat._ */ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, private val registersArr: Array[EClassCall], - private val boundVarsArr: 
Array[MixedTree[NodeT, EClassCall]], + private val boundVarsArr: Array[EClassCall], private val boundSlotsArr: Array[Slot], private val boundNodesArr: Array[ENode[NodeT]], private val homePool: MutableMachineState.Pool[NodeT]) { @@ -165,13 +165,19 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, * Bind a variable to a value. * @param value The value to bind the variable to. */ - def bindVar(value: MixedTree[NodeT, EClassCall]): Unit = { + def bindVar(value: EClassCall): Unit = { boundVarsArr(varIdx) = value varIdx += 1 } private def boundVars: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]] = { - ArrayMap.unsafeWrapArrays(effects.boundVars.unsafeArray, java.util.Arrays.copyOf(boundVarsArr, varIdx), varIdx) + val values = new Array[MixedTree[NodeT, EClassCall]](varIdx) + var i = 0 + while (i < varIdx) { + values(i) = MixedTree.Atom(boundVarsArr(i)) + i += 1 + } + ArrayMap.unsafeWrapArrays(effects.boundVars.unsafeArray, values, varIdx) } private def boundSlots: ArrayMap[Slot, Slot] = { @@ -206,7 +212,7 @@ object MutableMachineState { // Preallocate single empty arrays of the right types to avoid repeated allocations // when effects report zero bound vars/slots/nodes. // Unsafe casts are safe because these arrays are never written to. 
- private val emptyVars = new Array[MixedTree[_, EClassCall]](0) + private val emptyVars = new Array[EClassCall](0) private val emptySlots = new Array[Slot](0) private val emptyNodes = new Array[ENode[_]](0) @@ -221,7 +227,7 @@ object MutableMachineState { val m = new MutableMachineState[NodeT]( effects, new Array[EClassCall](1 + effects.createdRegisters), - if (varsLen > 0) new Array[MixedTree[NodeT, EClassCall]](varsLen) else emptyVars.asInstanceOf[Array[MixedTree[NodeT, EClassCall]]], + if (varsLen > 0) new Array[EClassCall](varsLen) else emptyVars, if (slotsLen > 0) new Array[Slot](slotsLen) else emptySlots, if (nodesLen > 0) new Array[ENode[NodeT]](nodesLen) else emptyNodes.asInstanceOf[Array[ENode[NodeT]]], pool diff --git a/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala b/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala index 0caf7cca..7d436bd5 100644 --- a/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/rewriting/patterns/MutableMachineStateTest.scala @@ -63,8 +63,8 @@ class MutableMachineStateTest { val m = MutableMachineState[Any](root, effects) // Values for the variables - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] m.bindVar(mt1) m.bindVar(mt2) @@ -170,8 +170,8 @@ class MutableMachineStateTest { val m = MutableMachineState[Any](root, effects) // Bind two variables - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] m.bindVar(mt1) 
m.bindVar(mt2) @@ -269,7 +269,7 @@ class MutableMachineStateTest { val ms1 = pool.borrow(root1) // Mutate state to non-zero indices - val mt: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt: EClassCall = null.asInstanceOf[EClassCall] ms1.bindVar(mt) ms1.bindVar(mt) assertEquals(2, ms1.boundVarsCount) @@ -390,8 +390,8 @@ class MutableMachineStateTest { val ms = pool.borrow(root) // Populate state: two vars and one node with one slot use - val mt1: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] - val mt2: MixedTree[Any, EClassCall] = null.asInstanceOf[MixedTree[Any, EClassCall]] + val mt1: EClassCall = null.asInstanceOf[EClassCall] + val mt2: EClassCall = null.asInstanceOf[EClassCall] ms.bindVar(mt1) ms.bindVar(mt2) From 3123741c8bda59e7c6d21290aeeacdddfaa7c64c Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 11:42:18 -0500 Subject: [PATCH 14/23] Refactor ApplierOps to use bind method for variable substitution --- .../foresight/eqsat/examples/arithWithLang/ApplierOps.scala | 2 +- .../main/scala/foresight/eqsat/examples/arith/ApplierOps.scala | 2 +- .../main/scala/foresight/eqsat/examples/liar/ApplierOps.scala | 2 +- .../main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala index bc1d6b80..f82a9e6e 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala @@ -46,7 +46,7 @@ object ApplierOps { } val substituted = subst(extractedExpr) - val newMatch = m.copy(varMapping = m.varMapping + (destination.variable -> L.toTree[EClassCall](substituted))) + val newMatch = m.bind(destination.variable, L.toTree[EClassCall](substituted)) 
applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala index afcaa6b5..e4131e58 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala @@ -36,7 +36,7 @@ object ApplierOps { } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala index 353b4815..670e82a5 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala @@ -44,7 +44,7 @@ object ApplierOps { } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala index e4af4fa1..41a1792b 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala @@ -36,7 +36,7 @@ object ApplierOps { } val substituted = subst(extracted) - val newMatch = m.copy(varMapping = m.varMapping + (destination -> substituted)) + val newMatch = m.bind(destination, substituted) applier.apply(newMatch, egraph, builder) } } From 74deae83be5a9fc77358399e8de0ba5f53776d50 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:18:33 -0500 Subject: [PATCH 15/23] Introduce 
specialized CallTree type --- .../examples/arithWithLang/ApplierOps.scala | 4 +- .../eqsat/examples/arithWithLang/Rules.scala | 2 +- .../eqsat/examples/arith/ApplierOps.scala | 4 +- .../eqsat/examples/arith/Rules.scala | 4 +- .../eqsat/examples/liar/ApplierOps.scala | 12 +-- .../eqsat/examples/liar/SearcherOps.scala | 22 ++--- .../eqsat/examples/sdql/ApplierOps.scala | 4 +- .../foresight/eqsat/lang/Language.scala | 29 +++++++ .../scala/foresight/eqsat/EClassCall.scala | 83 ++++++++++++++++++- .../commands/CommandScheduleBuilder.scala | 12 +-- .../eqsat/extraction/Extractor.scala | 25 +++++- .../eqsat/immutable/EGraphLike.scala | 24 +++++- .../eqsat/readonly/AnalysisMetadata.scala | 24 +++++- .../rewriting/patterns/MachineState.scala | 10 +-- .../patterns/MutableMachineState.scala | 18 ++-- .../rewriting/patterns/PatternApplier.scala | 2 +- .../rewriting/patterns/PatternMatch.scala | 10 +-- .../commands/CommandScheduleBuilderTest.scala | 28 +++---- 18 files changed, 240 insertions(+), 77 deletions(-) diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala index f82a9e6e..cffb09f1 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/ApplierOps.scala @@ -32,7 +32,7 @@ object ApplierOps { def subst(tree: ArithExpr): ArithExpr = { tree match { - case Var(slot) if slot == m(from) => L.fromTree[EClassCall](m(to.variable)) + case Var(slot) if slot == m(from) => L.fromTree(m(to.variable)) case Var(slot) => Var(slot) case Lam(param, body) => Lam(param, subst(body)) case App(fun, arg) => App(subst(fun), subst(arg)) @@ -46,7 +46,7 @@ object ApplierOps { } val substituted = subst(extractedExpr) - val newMatch = m.bind(destination.variable, L.toTree[EClassCall](substituted)) + val newMatch = m.bind(destination.variable, L.toCallTree(substituted)) 
applier.apply(newMatch, egraph, builder) } } diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala index 4e4ac43a..871c2859 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/arithWithLang/Rules.scala @@ -64,7 +64,7 @@ final case class Rules()(using L: Language[ArithExpr]) { result.toSeq.map { value => // If a constant value is found, create a new Number node and bind it to the variable, overwriting the // original binding. - subst.bind(x.variable, L.toTree(Number(value))) + subst.bind(x.variable, L.toCallTree(Number(value))) } }), L.toApplier(x) diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala index e4131e58..61f9e7f9 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/arith/ApplierOps.scala @@ -27,11 +27,11 @@ object ApplierOps { override def apply(m: PatternMatch[ArithIR], egraph: EGraphWithMetadata[ArithIR, EGraphT], builder: CommandScheduleBuilder[ArithIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArithIR].extractor[EGraphT](m(source), egraph) - def subst(tree: Tree[ArithIR]): MixedTree[ArithIR, EClassCall] = { + def subst(tree: Tree[ArithIR]): CallTree[ArithIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq()) if use == m(from) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala b/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala index 795d60ea..1040603d 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala 
+++ b/examples/src/main/scala/foresight/eqsat/examples/arith/Rules.scala @@ -4,7 +4,7 @@ import foresight.eqsat.examples.arith.ApplierOps._ import foresight.eqsat.readonly.{EGraph, EGraphWithMetadata} import foresight.eqsat.rewriting.Rule import foresight.eqsat.rewriting.patterns.{Pattern, PatternMatch} -import foresight.eqsat.{MixedTree, Slot} +import foresight.eqsat.{CallTree, MixedTree, Slot} /** * This object contains a collection of rules for rewriting arithmetic expressions. @@ -63,7 +63,7 @@ object Rules { result.toSeq.map { value => // If a constant value is found, create a new Number node and bind it to the variable, overwriting the // original binding. - subst.bind(x, Number(value)) + subst.bind(x, CallTree.from(Number(value))) } }), MixedTree.Atom[ArithIR, Pattern.Var](x).toApplier diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala index 670e82a5..aebc522e 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/ApplierOps.scala @@ -28,18 +28,18 @@ object ApplierOps { override def apply(m: PatternMatch[ArrayIR], egraph: EGraphWithMetadata[ArrayIR, EGraphT], builder: CommandScheduleBuilder[ArrayIR]): Unit = { val extracted = ExtractionAnalysis.smallest[ArrayIR].extractor[EGraphT](m(source), egraph) - def typeOf(tree: MixedTree[ArrayIR, EClassCall]): MixedTree[Type, EClassCall] = { - TypeInferenceAnalysis.get(egraph)(tree, egraph) + def typeOf(tree: CallTree[ArrayIR]): CallTree[Type] = { + CallTree.from(TypeInferenceAnalysis.get(egraph)(tree, egraph)) } - def subst(tree: Tree[ArrayIR]): MixedTree[ArrayIR, EClassCall] = { + def subst(tree: Tree[ArrayIR]): CallTree[ArrayIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq(fromType)) - if use == m(from) && typeOf(m(to)) == MixedTree.fromTree(fromType) => + if use == m(from) && typeOf(m(to)) == 
CallTree.from(fromType) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } @@ -66,7 +66,7 @@ object ApplierOps { val tree = applier.instantiate(m) val realTree = tree.mapAtoms(_.asInstanceOf[EClassCall]) inferType(realTree, egraph) - val c = builder.addSimplifiedReal(realTree, egraph) + val c = builder.addSimplifiedReal(CallTree.from(realTree), egraph) builder.unionSimplified(EClassSymbol.real(m.root), c, egraph) } } diff --git a/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala b/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala index 8b2ddf20..fca010d6 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/liar/SearcherOps.scala @@ -4,7 +4,7 @@ import foresight.eqsat.examples.liar.TypeRequirements.RequirementsSearcherContin import foresight.eqsat.parallel.ParallelMap import foresight.eqsat.rewriting.patterns.{CompiledPattern, Pattern, PatternMatch} import foresight.eqsat.rewriting.{Applier, ReversibleSearcher, Searcher} -import foresight.eqsat.MixedTree +import foresight.eqsat.{CallTree, EClassCall, MixedTree} import foresight.eqsat.immutable.{EGraph, EGraphLike, EGraphWithMetadata} object SearcherOps { @@ -23,8 +23,8 @@ object SearcherOps { searcher.map((m, egraph) => { val newVarMapping = m.varMapping ++ types.map { case (value, t) => - val (call, newEGraph) = egraph.add(m(value)) - t -> TypeInferenceAnalysis.get(newEGraph)(call, newEGraph) + val (call, newEGraph) = egraph.add(m(value).toMixedTree) + t -> CallTree.from(TypeInferenceAnalysis.get(newEGraph)(call, newEGraph)) } PatternMatch(m.root, newVarMapping, m.slotMapping) }) @@ -53,8 +53,8 @@ object SearcherOps { searcher.filter((m, egraph) => { values.forall(v => { m(v) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType.isInstanceOf[Value] - case 
MixedTree.Node(nodeType, _, _, _) => nodeType.isInstanceOf[Value] + case c: EClassCall => egraph.nodes(c).head.nodeType.isInstanceOf[Value] + case CallTree.Node(nodeType, _, _, _) => nodeType.isInstanceOf[Value] } }) }) @@ -69,8 +69,8 @@ object SearcherOps { def requireNonFunctionType(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType != FunctionType - case MixedTree.Node(nodeType, _, _, _) => nodeType != FunctionType + case c: EClassCall => egraph.nodes(c).head.nodeType != FunctionType + case CallTree.Node(nodeType, _, _, _) => nodeType != FunctionType } }) } @@ -83,8 +83,8 @@ object SearcherOps { def requireInt32Type(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType == Int32Type - case MixedTree.Node(nodeType, _, _, _) => nodeType == Int32Type + case c: EClassCall => egraph.nodes(c).head.nodeType == Int32Type + case CallTree.Node(nodeType, _, _, _) => nodeType == Int32Type } }) } @@ -97,8 +97,8 @@ object SearcherOps { def requireDoubleType(t: Pattern.Var): Searcher[ArrayIR, PatternMatch[ArrayIR], EGraphT] = { searcher.filter((m, egraph) => { m(t) match { - case MixedTree.Atom(c) => egraph.nodes(c).head.nodeType == DoubleType - case MixedTree.Node(nodeType, _, _, _) => nodeType == DoubleType + case c: EClassCall => egraph.nodes(c).head.nodeType == DoubleType + case CallTree.Node(nodeType, _, _, _) => nodeType == DoubleType } }) } diff --git a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala index 41a1792b..bb89d51e 100644 --- a/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala +++ b/examples/src/main/scala/foresight/eqsat/examples/sdql/ApplierOps.scala @@ -27,11 +27,11 @@ object ApplierOps { override def 
apply(m: PatternMatch[SdqlIR], egraph: EGraphWithMetadata[SdqlIR, EGraphT], builder: CommandScheduleBuilder[SdqlIR]): Unit = { val extracted = ExtractionAnalysis.smallest[SdqlIR].extractor[EGraphT](m(source), egraph) - def subst(tree: Tree[SdqlIR]): MixedTree[SdqlIR, EClassCall] = { + def subst(tree: Tree[SdqlIR]): CallTree[SdqlIR] = { tree match { case Tree(Var, Seq(), Seq(use), Seq()) if use == m(from) => m(to) case Tree(nodeType, defs, uses, args) => - MixedTree.Node(nodeType, defs, uses, args.map(subst)) + CallTree.Node(nodeType, defs, uses, args.map(subst)) } } diff --git a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala index cbc68550..32b0ff74 100644 --- a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala +++ b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala @@ -90,6 +90,35 @@ trait Language[E]: */ def fromTree[A](n: MixedTree[Op, A])(using dec: AtomDecoder[E, A]): E + /** + * Decode a call tree back into a surface AST `E`. + * + * @param n A call tree with [[Op]] internal nodes and [[EClassCall]] leaves. + * @param dec A given [[AtomDecoder]] that knows how to turn `EClassCall` atoms back into `E` fragments. + * @return The reconstructed surface expression. + * + * Note: decoding is typically partial in the presence of analysis-only atoms; see [[fromAnalysisNode]]. + */ + def fromTree(n: CallTree[Op])(using dec: AtomDecoder[E, EClassCall]): E = { + fromTree[EClassCall](n.toMixedTree) + } + + /** + * Encode a surface AST `e: E` into a call tree with [[EClassCall]] leaves. + * + * @param e A surface AST node. + * @param enc A given [[AtomEncoder]] that knows how to encode `E` atoms into `EClassCall`. + * @return A call tree equivalent to `e`. 
+ * + * @example + * {{{ + * val callTree: CallTree[Lang.Op] = Lang.toCallTree(surfaceExpr) + * }}} + */ + def toCallTree(e: E)(using enc: AtomEncoder[E, EClassCall]): CallTree[Op] = { + CallTree.from(toTree[EClassCall](e)) + } + /** * Encode a surface AST `e: E` into an immutable e-graph, returning the root e-class call * and the new e-graph containing `e`. diff --git a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala index d6ee2399..6a1947dc 100644 --- a/foresight/src/main/scala/foresight/eqsat/EClassCall.scala +++ b/foresight/src/main/scala/foresight/eqsat/EClassCall.scala @@ -1,8 +1,10 @@ package foresight.eqsat -import foresight.eqsat.collections.{SlotMap, SlotSet} +import foresight.eqsat.collections.{SlotMap, SlotSeq, SlotSet} import foresight.eqsat.readonly.EGraph +import scala.collection.compat.immutable.ArraySeq + /** * Represents the application of an [[EClassRef]] to a set of argument slots. * @@ -29,7 +31,7 @@ import foresight.eqsat.readonly.EGraph * val call2 = EClassCall(subXY, SlotMap(x -> b, y -> a)) // represents "b - a" * }}} */ -final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol { +final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol with CallTree[Nothing] { /** * The set of slots used as arguments in this application, in sequence order. * These are the slots referenced by the argument values, not the parameter slots. @@ -39,7 +41,7 @@ final case class EClassCall(ref: EClassRef, args: SlotMap) extends EClassSymbol /** * The set of distinct slots used as arguments in this application. */ - def slotSet: SlotSet = args.valueSet + override def slotSet: SlotSet = args.valueSet /** * Renames all argument slots in this application according to a given mapping. @@ -208,3 +210,78 @@ object EClassSymbol { */ def real(call: EClassCall): Real = call } + +/** + * A tree of nodes and e-class calls, representing expressions in an e-graph. 
+ * @tparam NodeT The type of the nodes in the e-graph. + */ +sealed trait CallTree[+NodeT] { + /** + * The set of distinct slots used in this call tree. + */ + def slotSet: Set[Slot] = this match { + case call: EClassCall => call.slotSet + case CallTree.Node(_, defs, uses, children) => + defs.toSet ++ uses ++ children.flatMap(_.slotSet) + } + + /** + * Converts this call tree into a mixed tree of nodes and e-class calls. + * @return The resulting mixed tree. + */ + final def toMixedTree: MixedTree[NodeT, EClassCall] = this match { + case call: EClassCall => MixedTree.Atom(call) + case CallTree.Node(n, defs, uses, children) => + MixedTree.Node(n, defs, uses, children.map(_.toMixedTree)) + } + + /** + * Maps all [[EClassCall]] nodes in this call tree using a given function. + * @param f The mapping function. + * @return The resulting call tree with mapped calls. + */ + final def mapCalls(f: EClassCall => EClassCall): CallTree[NodeT] = this match { + case call: EClassCall => f(call) + case CallTree.Node(n, defs, uses, children) => + CallTree.Node(n, defs, uses, children.map(_.mapCalls(f))) + } +} + +/** + * Constructors for [[CallTree]] nodes. + */ +object CallTree { + /** + * Represents a node in the call tree. + * @param node The node. + * @param definitions The slots defined by this node. + * @param uses The slots used by this node. + * @param children The child call trees. + * @tparam NodeT The type of the nodes in the e-graph. + */ + final case class Node[+NodeT](node: NodeT, definitions: SlotSeq, uses: SlotSeq, children: ArraySeq[CallTree[NodeT]]) + extends CallTree[NodeT] + + /** + * Converts a [[MixedTree]] of [[EClassCall]]s into a [[CallTree]]. + * @param tree The mixed tree to convert. + * @tparam NodeT The type of the nodes in the e-graph. + * @return The resulting call tree. 
+ */ + def from[NodeT](tree: MixedTree[NodeT, EClassCall]): CallTree[NodeT] = tree match { + case MixedTree.Atom(call) => call + case MixedTree.Node(n, defs, uses, args) => + Node(n, defs, uses, args.map(arg => from(arg))) + } + + /** + * Converts a [[Tree]] into a [[CallTree]]. + * @param tree The tree to convert. + * @tparam NodeT The type of the nodes in the e-graph. + * @return The resulting call tree. + */ + def from[NodeT](tree: Tree[NodeT]): CallTree[NodeT] = tree match { + case Tree(n, defs, uses, children) => + Node(n, defs, uses, children.map(child => from(child))) + } +} diff --git a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala index 66ff33a7..bea9a877 100644 --- a/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala +++ b/foresight/src/main/scala/foresight/eqsat/commands/CommandScheduleBuilder.scala @@ -2,7 +2,7 @@ package foresight.eqsat.commands import foresight.eqsat.collections.SlotSeq import foresight.eqsat.readonly.EGraph -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} import foresight.util.Debug import foresight.util.collections.UnsafeSeqFromArray @@ -81,7 +81,7 @@ trait CommandScheduleBuilder[NodeT] { } } - private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + private[eqsat] def addSimplifiedReal(tree: CallTree[NodeT], egraph: EGraph[NodeT]): EClassSymbol = { val refPool = IntRef.defaultPool val maxBatch = refPool.acquire(0) @@ -90,20 +90,20 @@ trait CommandScheduleBuilder[NodeT] { result } - private[eqsat] def addSimplifiedReal(tree: MixedTree[NodeT, EClassCall], + private[eqsat] def addSimplifiedReal(tree: CallTree[NodeT], egraph: EGraph[NodeT], maxBatch: IntRef, nodePool: ENode.Pool, refPool: IntRef.Pool): EClassSymbol = { tree match { - case MixedTree.Node(t, 
defs, uses, args) => + case CallTree.Node(t, defs, uses, args) => // Local accumulator for children of this node. val childMax = refPool.acquire(0) val argSymbols = CommandScheduleBuilder.symbolArrayFrom( args, childMax, nodePool, - (child: MixedTree[NodeT, EClassCall], mb: IntRef) => addSimplifiedReal(child, egraph, mb, nodePool, refPool) + (child: CallTree[NodeT], mb: IntRef) => addSimplifiedReal(child, egraph, mb, nodePool, refPool) ) val sym = addSimplifiedNode(t, defs, uses, argSymbols, childMax, egraph, nodePool) // Propagate maximum required batch up to the caller's accumulator. @@ -111,7 +111,7 @@ trait CommandScheduleBuilder[NodeT] { refPool.release(childMax) sym - case MixedTree.Atom(call) => + case call: EClassCall => // No insertion required; keep caller's accumulator unchanged. EClassSymbol.real(call) } diff --git a/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala b/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala index 634011a8..266dc0c4 100644 --- a/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala +++ b/foresight/src/main/scala/foresight/eqsat/extraction/Extractor.scala @@ -1,8 +1,7 @@ package foresight.eqsat.extraction -import foresight.eqsat.{EClassCall, MixedTree, Tree} +import foresight.eqsat.{CallTree, EClassCall, MixedTree, Tree, readonly} import foresight.eqsat.immutable.{EGraph, EGraphLike} -import foresight.eqsat.readonly /** * An extractor that converts e-graph references (e-class calls) into concrete expression trees. @@ -54,4 +53,26 @@ trait Extractor[NodeT, -Repr <: readonly.EGraph[NodeT]] { apply(call, egraph) } } + + /** + * Extracts a concrete expression tree from a [[CallTree]] by extracting an expression for each + * e-class call within it. + * + * @param tree The call tree to materialize prior to extraction. + * @param egraph The original e-graph; remains unchanged. + * @return A concrete [[Tree]] extracted from the materialized call. 
+ * + * @example + * {{{ + * val result: Tree[NodeT] = extractor(callTree, egraph) // `egraph` is unchanged + * }}} + */ + final def apply(tree: CallTree[NodeT], egraph: Repr): Tree[NodeT] = { + tree match { + case call: EClassCall => apply(call, egraph) + case CallTree.Node(n, defs, uses, children) => + val extractedChildren = children.map(child => apply(child, egraph)) + Tree(n, defs, uses, extractedChildren) + } + } } diff --git a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala index 74e597b3..15c50486 100644 --- a/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala +++ b/foresight/src/main/scala/foresight/eqsat/immutable/EGraphLike.scala @@ -1,8 +1,7 @@ package foresight.eqsat.immutable -import foresight.eqsat.{AddNodeResult, EClassCall, ENode, MixedTree, Tree} +import foresight.eqsat.{AddNodeResult, CallTree, EClassCall, ENode, MixedTree, Tree, readonly} import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.readonly import scala.collection.compat.immutable.ArraySeq @@ -148,6 +147,27 @@ trait EGraphLike[NodeT, +This <: EGraphLike[NodeT, This] with EGraph[NodeT]] ext } } + /** + * Adds a call tree to the e-graph. + * + * Child subtrees are added/resolved first; then the root e-node is added or found. + * Returns a new e-graph containing the result. + * + * @param tree The call tree to add. + * @return (E-class of the root, new e-graph). + */ + final def add(tree: CallTree[NodeT]): (EClassCall, This) = { + tree match { + case call: EClassCall => (call, this.asInstanceOf[This]) + case CallTree.Node(t, defs, uses, args) => + val (newArgs, graphWithArgs) = args.foldLeft((Seq.empty[EClassCall], this.asInstanceOf[This]))((acc, arg) => { + val (node, egraph) = acc._2.add(arg) + (acc._1 :+ node, egraph) + }) + graphWithArgs.add(ENode(t, defs, uses, newArgs)) + } + } + /** * Adds a pure tree to the e-graph. 
* diff --git a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala index a7b7e03e..66790b8d 100644 --- a/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala +++ b/foresight/src/main/scala/foresight/eqsat/readonly/AnalysisMetadata.scala @@ -1,6 +1,6 @@ package foresight.eqsat.readonly -import foresight.eqsat.{EClassCall, EClassRef, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassRef, MixedTree} import foresight.eqsat.metadata.Analysis /** @@ -55,6 +55,28 @@ trait AnalysisMetadata[NodeT, A] { } } + /** + * Evaluate a call tree whose leaves are either e-class applications or concrete nodes. + * + * For a call leaf, this delegates to [[apply(EClassCall,EGraph)]]. For a node leaf, it first + * computes the results of all argument subtrees and then invokes the analysis transfer function + * [[Analysis.make]] using the node’s definitions and uses provided by the tree. + * + * @param tree The call tree to evaluate. + * @param egraph The e-graph to resolve calls and canonicalization. + * @return The analysis result for the whole tree. + */ + final def apply(tree: CallTree[NodeT], + egraph: EGraph[NodeT]): A = { + tree match { + case CallTree.Node(node, defs, uses, args) => + val argsResults = args.map(apply(_, egraph)) + analysis.make(node, defs, uses, argsResults) + + case call: EClassCall => apply(call, egraph) + } + } + /** * Compute the analysis result for an already-canonicalized e-class application. 
* diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala index 9f14e937..4e0fe32a 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MachineState.scala @@ -1,9 +1,9 @@ package foresight.eqsat.rewriting.patterns -import foresight.eqsat.{EClassCall, ENode, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, ENode, Slot} import foresight.util.collections.ArrayMap -import scala.collection.compat._ +import scala.collection.compat.immutable.ArraySeq /** * The state of a pattern machine. @@ -14,7 +14,7 @@ import scala.collection.compat._ * @param boundNodes The nodes that are bound in the machine. * @tparam NodeT The type of the nodes in the e-graph. */ -final case class MachineState[NodeT](registers: immutable.ArraySeq[EClassCall], - boundVars: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]], +final case class MachineState[NodeT](registers: ArraySeq[EClassCall], + boundVars: ArrayMap[Pattern.Var, CallTree[NodeT]], boundSlots: ArrayMap[Slot, Slot], - boundNodes: immutable.ArraySeq[ENode[NodeT]]) + boundNodes: ArraySeq[ENode[NodeT]]) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala index e63e2c8c..49f92d27 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala @@ -1,9 +1,9 @@ package foresight.eqsat.rewriting.patterns -import foresight.eqsat.{EClassCall, ENode, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, ENode, Slot} import foresight.util.collections.ArrayMap -import scala.collection.compat._ +import scala.collection.compat.immutable.ArraySeq /** * 
A mutable machine state that preallocates fixed-size arrays @@ -170,14 +170,8 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, varIdx += 1 } - private def boundVars: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]] = { - val values = new Array[MixedTree[NodeT, EClassCall]](varIdx) - var i = 0 - while (i < varIdx) { - values(i) = MixedTree.Atom(boundVarsArr(i)) - i += 1 - } - ArrayMap.unsafeWrapArrays(effects.boundVars.unsafeArray, values, varIdx) + private def boundVars: ArrayMap[Pattern.Var, CallTree[NodeT]] = { + ArrayMap.unsafeWrapArrays(effects.boundVars.unsafeArray, java.util.Arrays.copyOf(boundVarsArr, varIdx), varIdx) } private def boundSlots: ArrayMap[Slot, Slot] = { @@ -191,8 +185,8 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, /** Convert to an immutable MachineState snapshot. */ def freeze(): MachineState[NodeT] = { - val regs = immutable.ArraySeq.unsafeWrapArray(registersArr.slice(0, regIdx)) - val nodes = immutable.ArraySeq.unsafeWrapArray(boundNodesArr.slice(0, nodeIdx)) + val regs = ArraySeq.unsafeWrapArray(registersArr.slice(0, regIdx)) + val nodes = ArraySeq.unsafeWrapArray(boundNodesArr.slice(0, nodeIdx)) MachineState(regs, boundVars, boundSlots, nodes) } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index 19aef04a..ff539096 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -47,7 +47,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT m: PatternMatch[NodeT]): MixedTree[NodeT, EClassSymbol] = { pattern match { case MixedTree.Atom(p) => p match { - case v: Pattern.Var => m(v).mapAtoms(EClassSymbol.real) + case v: Pattern.Var => m(v).toMixedTree.mapAtoms(EClassSymbol.real) } case 
MixedTree.Node(t, Seq(), uses, args) => diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala index b593faed..4c6e8c29 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala @@ -2,7 +2,7 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.PortableMatch -import foresight.eqsat.{EClassCall, MixedTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, MixedTree, Slot} import foresight.util.collections.ArrayMap import foresight.util.collections.StrictMapOps.toStrictMapOps @@ -15,7 +15,7 @@ import foresight.util.collections.StrictMapOps.toStrictMapOps * @tparam NodeT The type of the nodes in the e-graph. */ final case class PatternMatch[NodeT](root: EClassCall, - varMapping: ArrayMap[Pattern.Var, MixedTree[NodeT, EClassCall]], + varMapping: ArrayMap[Pattern.Var, CallTree[NodeT]], slotMapping: ArrayMap[Slot, Slot]) extends PortableMatch[NodeT, PatternMatch[NodeT]] { /** @@ -23,7 +23,7 @@ final case class PatternMatch[NodeT](root: EClassCall, * @param variable The variable. * @return The tree. */ - def apply(variable: Pattern.Var): MixedTree[NodeT, EClassCall] = varMapping(variable) + def apply(variable: Pattern.Var): CallTree[NodeT] = varMapping(variable) /** * Gets the slot that corresponds to a slot variable. @@ -38,7 +38,7 @@ final case class PatternMatch[NodeT](root: EClassCall, * @param value The value to bind the variable to. * @return The updated match with the new binding. 
*/ - def bind(variable: Pattern.Var, value: MixedTree[NodeT, EClassCall]): PatternMatch[NodeT] = { + def bind(variable: Pattern.Var, value: CallTree[NodeT]): PatternMatch[NodeT] = { PatternMatch(root, varMapping.updated(variable, value), slotMapping) } @@ -74,7 +74,7 @@ final case class PatternMatch[NodeT](root: EClassCall, override def port(egraph: EGraph[NodeT]): PatternMatch[NodeT] = { val newRoot = egraph.canonicalize(root) - val newVarMapping = varMapping.mapValuesStrict(_.mapAtoms(egraph.canonicalize)) + val newVarMapping = varMapping.mapValuesStrict(_.mapCalls(egraph.canonicalize)) val newSlotMapping = slotMapping PatternMatch(newRoot, newVarMapping, newSlotMapping) } diff --git a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala index 5e18370e..480e8fed 100644 --- a/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala +++ b/foresight/src/test/scala/foresight/eqsat/commands/CommandScheduleBuilderTest.scala @@ -2,7 +2,7 @@ package foresight.eqsat.commands import foresight.eqsat.collections.SlotSeq import foresight.eqsat.parallel.ParallelMap -import foresight.eqsat.{EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, ENodeSymbol, MixedTree} import foresight.eqsat.immutable.EGraph import org.junit.Test @@ -186,7 +186,7 @@ class CommandScheduleBuilderTest { // Use addSimplifiedReal on an Atom (real call). 
val sym = baseBuilder.addSimplifiedReal( - MixedTree.Atom[Int, EClassCall](realCall), + realCall, g1 ) @@ -206,7 +206,7 @@ class CommandScheduleBuilderTest { val g0 = EGraph.empty[Int] // Build a trivial 1-node tree: Node(7, [], [], []) - val tree = MixedTree.Node[Int, EClassCall](7, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty[MixedTree[Int, EClassCall]]) + val tree = CallTree.Node[Int](7, SlotSeq.empty, SlotSeq.empty, ArraySeq.empty[CallTree[Int]]) val sym = builder.addSimplifiedReal(tree, g0) // We only care that exactly one node gets scheduled and applying the schedule grows the graph by 1. @@ -240,14 +240,14 @@ class CommandScheduleBuilderTest { val realLeaf: EClassCall = g1.canonicalize(g1.classes.head) // child = Node(10, [], [], [ Atom(realLeaf) ]) - val child = MixedTree.Node[Int, EClassCall]( + val child = CallTree.Node[Int]( 10, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realLeaf)) + ArraySeq(realLeaf) ) // parent = Node(20, [], [], [ child ]) - val parent = MixedTree.Node[Int, EClassCall]( + val parent = CallTree.Node[Int]( 20, SlotSeq.empty, SlotSeq.empty, @@ -296,29 +296,29 @@ class CommandScheduleBuilderTest { } // Level 1 nodes (directly depend on real atoms) - val child1 = MixedTree.Node[Int, EClassCall]( + val child1 = CallTree.Node[Int]( 201, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realA)) + ArraySeq(realA) ) - val child2 = MixedTree.Node[Int, EClassCall]( + val child2 = CallTree.Node[Int]( 202, SlotSeq.empty, SlotSeq.empty, - ArraySeq(MixedTree.Atom[Int, EClassCall](realA), MixedTree.Atom[Int, EClassCall](realB)) + ArraySeq(realA, realB) ) val childSym1 = builder.addSimplifiedReal(child1, g1) val childSym2 = builder.addSimplifiedReal(child2, g1) // Level 2 nodes (depend on Level 1) - val parent1 = MixedTree.Node[Int, EClassCall]( + val parent1 = CallTree.Node[Int]( 301, SlotSeq.empty, SlotSeq.empty, - ArraySeq(child1, MixedTree.Atom[Int, EClassCall](realB)) + ArraySeq(child1, realB) 
) - val parent2 = MixedTree.Node[Int, EClassCall]( + val parent2 = CallTree.Node[Int]( 302, SlotSeq.empty, SlotSeq.empty, @@ -328,7 +328,7 @@ class CommandScheduleBuilderTest { val parentSym2 = builder.addSimplifiedReal(parent2, g1) // Level 3 node (root depends on both Level 2 nodes) - val root = MixedTree.Node[Int, EClassCall]( + val root = CallTree.Node[Int]( 401, SlotSeq.empty, SlotSeq.empty, From 7e99429cc92867578adc5b92b4745f700eb34477 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:24:57 -0500 Subject: [PATCH 16/23] Comment out array nullification in ENode to optimize memory management --- foresight/src/main/scala/foresight/eqsat/ENode.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/ENode.scala b/foresight/src/main/scala/foresight/eqsat/ENode.scala index 5cedec57..6d2f768c 100644 --- a/foresight/src/main/scala/foresight/eqsat/ENode.scala +++ b/foresight/src/main/scala/foresight/eqsat/ENode.scala @@ -533,7 +533,7 @@ object ENode { if (arr eq null) return val len = arr.length if (len == 0) return - java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + // java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) val q = slotDeque(len) if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) } @@ -546,7 +546,7 @@ object ENode { if (arr eq null) return val len = arr.length if (len == 0) return - java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) + // java.util.Arrays.fill(arr.asInstanceOf[Array[AnyRef]], null) val q = callDeque(len) if ((q ne null) && q.size() < perBucketCap) q.addFirst(arr) } From e66401b11d0c5fee05dc4c40b07bca3fe9765367 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:35:48 -0500 Subject: [PATCH 17/23] Add AbstractPatternMatch trait and refactor PatternMatch to extend it --- .../patterns/AbstractPatternMatch.scala | 26 +++++++++++++++++++ .../rewriting/patterns/PatternMatch.scala | 7 ++--- 2 files changed, 30 
insertions(+), 3 deletions(-) create mode 100644 foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala new file mode 100644 index 00000000..536b645e --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala @@ -0,0 +1,26 @@ +package foresight.eqsat.rewriting.patterns + +import foresight.eqsat.{CallTree, Slot} + +/** + * An abstract pattern match that maps pattern variables to trees and slots. + * + * @tparam NodeT The type of the nodes in the e-graph. + */ +trait AbstractPatternMatch[NodeT] { + /** + * Gets the tree that corresponds to a variable. + * + * @param variable The variable. + * @return The tree. + */ + def apply(variable: Pattern.Var): CallTree[NodeT] + + /** + * Gets the slot that corresponds to a slot variable. + * + * @param slot The slot variable. + * @return The slot. + */ + def apply(slot: Slot): Slot +} diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala index 4c6e8c29..e43545a0 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala @@ -16,21 +16,22 @@ import foresight.util.collections.StrictMapOps.toStrictMapOps */ final case class PatternMatch[NodeT](root: EClassCall, varMapping: ArrayMap[Pattern.Var, CallTree[NodeT]], - slotMapping: ArrayMap[Slot, Slot]) extends PortableMatch[NodeT, PatternMatch[NodeT]] { + slotMapping: ArrayMap[Slot, Slot]) + extends AbstractPatternMatch[NodeT] with PortableMatch[NodeT, PatternMatch[NodeT]] { /** * Gets the tree that corresponds to a variable. * @param variable The variable. * @return The tree. 
*/ - def apply(variable: Pattern.Var): CallTree[NodeT] = varMapping(variable) + override def apply(variable: Pattern.Var): CallTree[NodeT] = varMapping(variable) /** * Gets the slot that corresponds to a slot variable. * @param slot The slot variable. * @return The slot. */ - def apply(slot: Slot): Slot = slotMapping(slot) + override def apply(slot: Slot): Slot = slotMapping(slot) /** * Creates an updated match with a new variable binding. From fba986011219101d397e429f676d4ed19ebe76e0 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:43:50 -0500 Subject: [PATCH 18/23] Refactor PatternMatch and PatternApplier to use AbstractPatternMatch and add get method --- .../patterns/AbstractPatternMatch.scala | 15 +++++++- .../rewriting/patterns/PatternApplier.scala | 38 ++++++++++++++----- .../rewriting/patterns/PatternMatch.scala | 7 ++++ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala index 536b645e..cf581cd6 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/AbstractPatternMatch.scala @@ -1,6 +1,6 @@ package foresight.eqsat.rewriting.patterns -import foresight.eqsat.{CallTree, Slot} +import foresight.eqsat.{CallTree, EClassCall, Slot} /** * An abstract pattern match that maps pattern variables to trees and slots. @@ -8,6 +8,11 @@ import foresight.eqsat.{CallTree, Slot} * @tparam NodeT The type of the nodes in the e-graph. */ trait AbstractPatternMatch[NodeT] { + /** + * The e-class in which the pattern was found. + */ + def root: EClassCall + /** * Gets the tree that corresponds to a variable. * @@ -23,4 +28,12 @@ trait AbstractPatternMatch[NodeT] { * @return The slot. 
*/ def apply(slot: Slot): Slot + + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + def get(slot: Slot): Option[Slot] } diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala index ff539096..07ab5102 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternApplier.scala @@ -3,8 +3,8 @@ package foresight.eqsat.rewriting.patterns import foresight.eqsat.collections.SlotSeq import foresight.eqsat.commands.{CommandScheduleBuilder, IntRef} import foresight.eqsat.readonly.EGraph -import foresight.eqsat.rewriting.{ReversibleApplier, Searcher} -import foresight.eqsat.{EClassSymbol, ENode, MixedTree, Slot} +import foresight.eqsat.rewriting.{Applier, ReversibleApplier, Searcher} +import foresight.eqsat.{CallTree, EClassCall, EClassSymbol, ENode, MixedTree, Slot} import scala.collection.compat.immutable.ArraySeq import java.util.ArrayDeque @@ -17,9 +17,9 @@ import java.util.ArrayDeque * @tparam EGraphT The type of the e-graph that the applier applies the match to. 
*/ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedTree[NodeT, Pattern.Var]) - extends ReversibleApplier[NodeT, PatternMatch[NodeT], EGraphT] { + extends ReversibleApplier[NodeT, PatternMatch[NodeT], EGraphT] with Applier[NodeT, AbstractPatternMatch[NodeT], EGraphT] { - override def apply(m: PatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { + override def apply(m: AbstractPatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): Unit = { val symbol = instantiateAsSimplifiedAddCommand(pattern, m, egraph, builder) builder.unionSimplified(EClassSymbol.real(m.root), symbol, egraph) } @@ -68,13 +68,13 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT private final class SimplifiedAddCommandInstantiator { // Mutable state to allow pooling; initialized via init() before use. - private var m: PatternMatch[NodeT] = _ + private var m: AbstractPatternMatch[NodeT] = _ private var egraph: EGraphT = _ private var builder: CommandScheduleBuilder[NodeT] = _ private var nodePool: ENode.Pool = _ private var refPool: IntRef.Pool = _ - def init(m0: PatternMatch[NodeT], + def init(m0: AbstractPatternMatch[NodeT], egraph0: EGraphT, builder0: CommandScheduleBuilder[NodeT], nodePool0: ENode.Pool, @@ -104,12 +104,32 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT case MixedTree.Node(t, defs, uses, args) => val defSlots = defs.map { (s: Slot) => - m.slotMapping.get(s) match { + m.get(s) match { case Some(v) => v case None => Slot.fresh() } } - val newMatch = m.copy(slotMapping = m.slotMapping ++ defs.zip(defSlots)) + + val newMatch = new AbstractPatternMatch[NodeT] { + override def root: EClassCall = m.root + override def apply(variable: Pattern.Var): CallTree[NodeT] = m.apply(variable) + override def apply(slot: Slot): Slot = { + if (defs.contains(slot)) { + val index = defs.indexOf(slot) + defSlots(index) + } else { + m.apply(slot) 
+ } + } + override def get(slot: Slot): Option[Slot] = { + if (defs.contains(slot)) { + val index = defs.indexOf(slot) + Some(defSlots(index)) + } else { + m.get(slot) + } + } + } // Acquire a nested instantiator. val nested = SimplifiedAddCommandInstantiator.acquire() @@ -160,7 +180,7 @@ final case class PatternApplier[NodeT, EGraphT <: EGraph[NodeT]](pattern: MixedT } private def instantiateAsSimplifiedAddCommand(pattern: MixedTree[NodeT, Pattern.Var], - m: PatternMatch[NodeT], + m: AbstractPatternMatch[NodeT], egraph: EGraphT, builder: CommandScheduleBuilder[NodeT]): EClassSymbol = { diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala index e43545a0..64c77a06 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/PatternMatch.scala @@ -33,6 +33,13 @@ final case class PatternMatch[NodeT](root: EClassCall, */ override def apply(slot: Slot): Slot = slotMapping(slot) + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + override def get(slot: Slot): Option[Slot] = slotMapping.get(slot) + /** * Creates an updated match with a new variable binding. * @param variable The variable to bind. 
From 5f2ba342bc8a95f3f7c07876dbf2b83a2cbafd4c Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:48:57 -0500 Subject: [PATCH 19/23] Extend MutableMachineState to implement AbstractPatternMatch and add methods for variable and slot retrieval --- .../patterns/MutableMachineState.scala | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala index 49f92d27..378a19ca 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/MutableMachineState.scala @@ -14,7 +14,8 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, private val boundVarsArr: Array[EClassCall], private val boundSlotsArr: Array[Slot], private val boundNodesArr: Array[ENode[NodeT]], - private val homePool: MutableMachineState.Pool[NodeT]) { + private val homePool: MutableMachineState.Pool[NodeT]) + extends AbstractPatternMatch[NodeT] { private var regIdx: Int = 0 private var varIdx: Int = 0 @@ -196,6 +197,51 @@ final class MutableMachineState[NodeT] private(val effects: Instruction.Effects, def toPatternMatch: PatternMatch[NodeT] = { PatternMatch(registersArr(0), boundVars, boundSlots) } + + /** + * The e-class in which the pattern was found. + */ + override def root: EClassCall = registersArr(0) + + /** + * Gets the tree that corresponds to a variable. + * + * @param variable The variable. + * @return The tree. + */ + override def apply(variable: Pattern.Var): CallTree[NodeT] = { + val i = effects.boundVars.indexOf(variable) + if (i < 0 || i >= varIdx) + throw new NoSuchElementException(s"Variable $variable not bound in this state") + + boundVarsArr(i) + } + + /** + * Gets the slot that corresponds to a slot variable. 
+ * + * @param slot The slot variable. + * @return The slot. + */ + override def apply(slot: Slot): Slot = { + val i = effects.boundSlots.indexOf(slot) + if (i < 0 || i >= slotIdx) + throw new NoSuchElementException(s"Slot $slot not bound in this state") + + boundSlotsArr(i) + } + + /** + * Gets the slot that corresponds to a slot variable, returning an option. + * + * @param slot The slot variable. + * @return The slot if it exists; None otherwise. + */ + override def get(slot: Slot): Option[Slot] = { + val i = effects.boundSlots.indexOf(slot) + if (i >= 0 && i < slotIdx) Some(boundSlotsArr(i)) + else None + } } /** From 89acdc043d12c9ea38d458c1151c2563a4e26d60 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:49:10 -0500 Subject: [PATCH 20/23] Add searchBorrowed method to CompiledPattern for e-graph pattern matching --- .../rewriting/patterns/CompiledPattern.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala index 93c3a41d..a267c355 100644 --- a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/CompiledPattern.scala @@ -47,6 +47,27 @@ final case class CompiledPattern[NodeT, EGraphT <: EGraph[NodeT]](pattern: Mixed }) } + /** + * Searches for matches of the pattern in an e-graph, calling a continuation for each match. + * If the continuation returns false, the search is stopped. + * + * Borrows the machine state and passes it to the continuation. The continuation must not + * retain the state after returning. + * + * @param call The e-class application to search for. + * @param egraph The e-graph to search in. + * @param continuation A continuation that is called for each match of the pattern. If the continuation returns false, + * the search is stopped. 
+ */ + def searchBorrowed(call: EClassCall, egraph: EGraphT, continuation: (AbstractPatternMatch[NodeT], EGraphT) => Boolean): Unit = { + val state = machinePool.borrow(call) + Machine.run(egraph, state, instructions, (state: MutableMachineState[NodeT]) => { + val result = continuation(state, egraph) + state.release() + result + }) + } + /** * Searches for matches of the pattern in an e-graph. * @param call The e-class application to search for. From 3491f3017171835af6ae82fb11795e48df2ffd57 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:51:56 -0500 Subject: [PATCH 21/23] Add StateBorrowingMachineEClassSearcher for efficient pattern matching in e-graphs --- .../StateBorrowingMachineEClassSearcher.scala | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala diff --git a/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala new file mode 100644 index 00000000..d614a319 --- /dev/null +++ b/foresight/src/main/scala/foresight/eqsat/rewriting/patterns/StateBorrowingMachineEClassSearcher.scala @@ -0,0 +1,49 @@ +package foresight.eqsat.rewriting.patterns + +import foresight.eqsat.rewriting.{Applier, EClassSearcher, ReversibleSearcher, SearcherContinuation} +import foresight.eqsat.EClassCall +import foresight.eqsat.readonly.EGraph + +/** + * A phase of a searcher that searches for matches of a pattern machine in an e-graph. + * + * This searcher uses a borrowing mechanism for the machine state, allowing for more efficient + * searches by reusing state without the overhead of copying. + * + * @param pattern The pattern to search for. + * @tparam NodeT The type of the nodes in the e-graph. + * @tparam EGraphT The type of the e-graph that the searcher searches in. 
+ */ +final case class StateBorrowingMachineEClassSearcher[ + NodeT, + EGraphT <: EGraph[NodeT] +](pattern: CompiledPattern[NodeT, EGraphT], + buildContinuation: SearcherContinuation.ContinuationBuilder[NodeT, AbstractPatternMatch[NodeT], EGraphT]) + + extends EClassSearcher[NodeT, AbstractPatternMatch[NodeT], EGraphT] { + + protected override def search(call: EClassCall, egraph: EGraphT, continuation: SearcherContinuation.Continuation[NodeT, AbstractPatternMatch[NodeT], EGraphT]): Unit = { + pattern.searchBorrowed(call, egraph, continuation) + } + + override def withContinuationBuilder(continuation: SearcherContinuation.ContinuationBuilder[NodeT, AbstractPatternMatch[NodeT], EGraphT]): StateBorrowingMachineEClassSearcher[NodeT, EGraphT] = { + copy(buildContinuation = continuation) + } +} + +object StateBorrowingMachineEClassSearcher { + /** + * Creates a `StateBorrowingMachineEClassSearcher` from a compiled pattern. + * + * @param pattern The pattern to search for. + * @tparam NodeT The type of the nodes in the e-graph. + * @tparam EGraphT The type of the e-graph that the searcher searches in. + * @return A new `StateBorrowingMachineEClassSearcher` instance. 
+ */ + def apply[ + NodeT, + EGraphT <: EGraph[NodeT] + ](pattern: CompiledPattern[NodeT, EGraphT]): StateBorrowingMachineEClassSearcher[NodeT, EGraphT] = { + StateBorrowingMachineEClassSearcher(pattern, SearcherContinuation.identityBuilder) + } +} From dc8c86bbe29cb620fd50026c5ebb1f51cc6fb712 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 12:58:34 -0500 Subject: [PATCH 22/23] Add borrowing searcher and rules for pattern matching in e-graphs --- .../foresight/eqsat/lang/Language.scala | 110 +++++++++++++++++- 1 file changed, 108 insertions(+), 2 deletions(-) diff --git a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala index 32b0ff74..c2ae8a53 100644 --- a/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala +++ b/foresight/src/main/scala-3/foresight/eqsat/lang/Language.scala @@ -1,8 +1,8 @@ package foresight.eqsat.lang import foresight.eqsat.extraction.ExtractionAnalysis -import foresight.eqsat.rewriting.patterns.{Pattern, PatternApplier, PatternMatch} -import foresight.eqsat.rewriting.{ReversibleSearcher, Rule} +import foresight.eqsat.rewriting.patterns.{AbstractPatternMatch, Pattern, PatternApplier, PatternMatch, StateBorrowingMachineEClassSearcher} +import foresight.eqsat.rewriting.{ReversibleSearcher, Rule, Searcher} import foresight.eqsat.* import foresight.eqsat.immutable import foresight.eqsat.mutable @@ -376,6 +376,19 @@ trait Language[E]: (using enc: AtomEncoder[E, Pattern.Var]): ReversibleSearcher[Op, PatternMatch[Op], EGraphT] = toTree(e).toSearcher + /** + * Build a borrowing searcher from a surface expression, by first encoding to a pattern tree. + * + * Requires an encoder for [[foresight.eqsat.rewriting.patterns.Pattern.Var]] leaves, i.e., a way to turn surface + * leaves into pattern variables. + * + * @param e Surface-side pattern. + * @tparam EGraphT An e-graph type that supports this language. 
+ */ + def toBorrowingSearcher[EGraphT <: EGraph[Op]](e: E) + (using enc: AtomEncoder[E, Pattern.Var]): Searcher[Op, AbstractPatternMatch[Op], EGraphT] = + StateBorrowingMachineEClassSearcher(toTree(e).compiled) + /** * Build a pattern applier from a surface expression, by first encoding to a pattern tree. * @@ -409,10 +422,39 @@ trait Language[E]: : Rule[Op, PatternMatch[Op], EGraphT] = Rule(name, toSearcher[EGraphT](lhs), toApplier[EGraphT](rhs)) + /** + * Construct a rewrite rule that uses a borrowing searcher under the hood. + * + * @param name Human-readable rule name (used in logs/diagnostics). + * @param lhs Surface pattern to match. + * @param rhs Surface template to build. + * @param enc Encoder for [[foresight.eqsat.rewriting.patterns.Pattern.Var]] leaves. + * @tparam EGraphT An e-graph type that supports this language. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String, lhs: E, rhs: E) + (using enc: AtomEncoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + Rule(name, toBorrowingSearcher[EGraphT](lhs), toApplier[EGraphT](rhs)) + /** Generate a fresh [[foresight.eqsat.rewriting.patterns.Pattern.Var]] and decode it back to a surface `E` value. */ private def freshPatternVar(using dec: AtomDecoder[E, Pattern.Var]): E = fromTree(MixedTree.Atom[Op, Pattern.Var](Pattern.Var.fresh())) + /** + * Create a borrowing rewrite rule by supplying a lambda that receives one fresh pattern variable. + * + * Works like [[rule(name)(f: E => (E,E))]] but compiles the LHS with a borrowing searcher. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: E => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives one fresh pattern variable. 
* @@ -481,6 +523,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with two fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives three fresh pattern variables. * @@ -509,6 +563,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with three fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives four fresh pattern variables. * @@ -538,6 +604,18 @@ trait Language[E]: val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with four fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f(freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives five fresh pattern variables. 
* @@ -569,6 +647,20 @@ trait Language[E]: ) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with five fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f( + freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar + ) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Create a rewrite rule by supplying a lambda that receives six fresh pattern variables. * @@ -600,6 +692,20 @@ trait Language[E]: ) rule[EGraphT](name, lhs, rhs) + /** + * Create a borrowing rewrite rule with six fresh variables. + */ + def borrowingRule[EGraphT <: EGraph[Op]] + (name: String) + (f: (E, E, E, E, E, E) => (E, E)) + (using enc: AtomEncoder[E, Pattern.Var], + dec: AtomDecoder[E, Pattern.Var]) + : Rule[Op, AbstractPatternMatch[Op], EGraphT] = + val (lhs, rhs) = f( + freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar, freshPatternVar + ) + borrowingRule[EGraphT](name, lhs, rhs) + /** * Reconstruct a surface expression `E` from an *analysis-view* of a single node. 
* From c34f93fd83ca0acfec1b640231707fe1db761055 Mon Sep 17 00:00:00 2001 From: jonathanvdc Date: Sun, 9 Nov 2025 13:01:31 -0500 Subject: [PATCH 23/23] Refactor Rules to use AbstractPatternMatch and state-borrowing rules --- .../scala-3/foresight/eqsat/examples/mm/Rules.scala | 10 +++++----- .../scala-3/foresight/eqsat/examples/mm/MMTest.scala | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala b/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala index 50d0da50..66ecdae8 100644 --- a/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala +++ b/examples/src/main/scala-3/foresight/eqsat/examples/mm/Rules.scala @@ -3,7 +3,7 @@ package foresight.eqsat.examples.mm import foresight.eqsat.lang.{Language, LanguageOp} import foresight.eqsat.readonly.EGraph import foresight.eqsat.rewriting.Rule -import foresight.eqsat.rewriting.patterns.PatternMatch +import foresight.eqsat.rewriting.patterns.AbstractPatternMatch import scala.language.implicitConversions @@ -13,17 +13,17 @@ import scala.language.implicitConversions final case class Rules()(using L: Language[LinalgExpr]) { type Op = LanguageOp[LinalgExpr] type LinalgEGraph = EGraph[LinalgIR] - type LinalgRule = Rule[LinalgIR, PatternMatch[LinalgIR], LinalgEGraph] + type LinalgRule = Rule[LinalgIR, AbstractPatternMatch[LinalgIR], LinalgEGraph] - import L.rule + import L.borrowingRule val matMulAssociativity1: LinalgRule = - rule("mul-associativity1") { (x, y, z) => + borrowingRule("mul-associativity1") { (x, y, z) => ((x * y) * z) -> (x * (y * z)) } val matMulAssociativity2: LinalgRule = - rule("mul-associativity2") { (x, y, z) => + borrowingRule("mul-associativity2") { (x, y, z) => (x * (y * z)) -> ((x * y) * z) } diff --git a/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala b/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala index d589bdbf..c3371166 100644 --- 
a/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala +++ b/examples/src/test/scala-3/foresight/eqsat/examples/mm/MMTest.scala @@ -1,8 +1,7 @@ package foresight.eqsat.examples.mm -import foresight.eqsat.examples.mm.{Fact, LinalgExpr, LinalgIR, Mat, Mul, Rules, *} import foresight.eqsat.lang._ -import foresight.eqsat.saturation.{MaximalRuleApplication, MaximalRuleApplicationWithCaching, Strategy} +import foresight.eqsat.saturation.{MaximalRuleApplication, Strategy} import foresight.eqsat.EClassCall import foresight.eqsat.immutable.EGraph import foresight.eqsat.mutable