Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 81 additions & 21 deletions foresight/src/main/scala/foresight/eqsat/ENode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import foresight.eqsat.collections.{SlotMap, SlotSeq, SlotSet}
import foresight.util.Debug
import foresight.util.collections.UnsafeSeqFromArray

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.compat.immutable.ArraySeq

/**
Expand All @@ -28,8 +27,8 @@ final class ENode[+NodeT] private (
private val _uses: Array[Slot],
private val _args: Array[EClassCall]
) extends Node[NodeT, EClassCall] with ENodeSymbol[NodeT] {
// Cached hash code to make hashing and equality fast
private val _hash: AtomicInteger = new AtomicInteger(0)
// Cached hash code to make hashing and equality fast (benign data race; like String.hash)
private var _hash: Int = 0

/**
* Slots introduced by this node that are scoped locally and invisible to parents. These are
Expand Down Expand Up @@ -294,49 +293,110 @@ final class ENode[+NodeT] private (
// --- case-class-like API preservation ---
override def toString: String = s"ENode($nodeType, $definitions, $uses, $args)"

@inline
private def callsEqualFast(a: Array[EClassCall], b: Array[EClassCall]): Boolean = {
if (a eq b) return true
if (a.length != b.length) return false
var i = 0
while (i < a.length) {
val ai = a(i); val bi = b(i)
// Compare the e-class ids first (cheap)
if (ai.ref.id != bi.ref.id) return false
// Then compare the SlotMap by reference first, fall back to equals only if needed
val aArgs = ai.args; val bArgs = bi.args
if ((aArgs ne bArgs) && !(aArgs == bArgs)) return false
i += 1
}
true
}

@inline
private def slotsEqualByRef(a: Array[Slot], b: Array[Slot]): Boolean = {
if (a eq b) return true
if (a.length != b.length) return false
var i = 0
while (i < a.length) {
if (a(i) ne b(i)) return false
i += 1
}
true
}

//noinspection ComparingUnrelatedTypes
override def equals(other: Any): Boolean = other match {
case that: ENode[_] =>
this.nodeType == that.nodeType &&
ENode.arraysEqual(this._definitions, that._definitions) &&
ENode.arraysEqual(this._uses, that._uses) &&
ENode.arraysEqual(this._args, that._args)
if (this eq that) return true
// Cheap hash pre-check to avoid deep scans during collision probes
if (this.hashCode() != that.hashCode()) return false
// Now structural checks
(this.nodeType == that.nodeType) &&
(this._definitions.length == that._definitions.length) &&
(this._uses.length == that._uses.length) &&
(this._args.length == that._args.length) &&
slotsEqualByRef(this._definitions, that._definitions) &&
slotsEqualByRef(this._uses, that._uses) &&
callsEqualFast(this._args, that._args)
case _ => false
}

private def computeHash(): Int = {
var h = 1
h = 31 * h + (if (nodeType == null) 0 else nodeType.hashCode)
@inline private def mix(h: Int, data: Int): Int = {
var k = data
k *= 0xcc9e2d51
k = (k << 15) | (k >>> 17)
k *= 0x1b873593
var res = h ^ k
res = (res << 13) | (res >>> 19)
res = res * 5 + 0xe6546b64
res
}
@inline private def avalanche(h: Int, len: Int): Int = {
var x = h ^ len
x ^= (x >>> 16)
x *= 0x85ebca6b
x ^= (x >>> 13)
x *= 0xc2b2ae35
x ^= (x >>> 16)
x
}

private def computeHash(): Int = {
// Murmur3-style mix for better avalanche; keep consistent with equals:
// - Slots are compared by reference => use identityHashCode for slots
// - Args compare ref.id and SlotMap by (ref or equals) => use id + identityHashCode(map) as primary signal
var h = 0
val nt = if (nodeType == null) 0 else nodeType.hashCode
h = mix(h, nt)
// definitions (by reference)
var i = 0
while (i < _definitions.length) {
h = 31 * h + _definitions(i).hashCode()
h = mix(h, System.identityHashCode(_definitions(i)))
i += 1
}

// uses (by reference)
i = 0
while (i < _uses.length) {
h = 31 * h + _uses(i).hashCode()
h = mix(h, System.identityHashCode(_uses(i)))
i += 1
}

// args: mix ref id and structure of the SlotMap; include size to separate small maps
i = 0
while (i < _args.length) {
val arg = _args(i)
h = 31 * h + arg.ref.id
h = 31 * h + arg.args.hashCode()
h = mix(h, arg.ref.id)
// Use structural hash for SlotMap to remain consistent with equals (which may consider distinct instances equal)
h = mix(h, arg.args.hashCode())
i += 1
}
h
// length salt to distinguish permutations/shapes with same prefix
avalanche(h, 1 + _definitions.length + _uses.length + (_args.length << 1))
}

override def hashCode(): Int = {
val cached = _hash.get()
@inline override def hashCode(): Int = {
val cached = _hash
if (cached != 0) return cached

val h = computeHash()
val result = if (h == 0) 1 else h
_hash.compareAndSet(0, result)
_hash = result
result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ final case class CommandSchedule[NodeT](batchZero: (ArraySeq[EClassSymbol.Virtua
def apply(egraph: mutable.EGraph[NodeT],
parallelize: ParallelMap): Boolean = {

val reification = new util.IdentityHashMap[EClassSymbol.Virtual, foresight.eqsat.EClassCall]()
val reificationMapEntries = batchZero._1.length + otherBatches.map(_._1.length).sum
val reification = new util.IdentityHashMap[EClassSymbol.Virtual, foresight.eqsat.EClassCall](reificationMapEntries)

var anyChanges: Boolean = false
anyChanges = anyChanges | applyBatchZero(egraph, parallelize, reification)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,6 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT]
}
}

/**
* Query the hash cons for the given node. Returns null if the node is not in the hash cons.
*
* @param node The node to query.
* @return The e-class reference of the node, or null if the node is not in the hash cons.
*/
private def nodeToClassOrNull(node: ENode[NodeT]): EClassRef = {
nodeToRefOrElse(node, EClassRef.Invalid)
}

private def slots(ref: EClassRef): SlotSet = dataForClass(ref).slots

/**
Expand Down Expand Up @@ -393,7 +383,7 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT]
* @param node The node to repair.
*/
def repairNodeWithoutSlots(node: ENode[NodeT]): Unit = {
val ref = nodeToClassOrNull(node)
val ref = nodeToRefOrInvalid(node)
if (Debug.isEnabled) {
assert(ref != EClassRef.Invalid, "The node to repair must be in the hash-cons.")
assert(!node.hasSlots, "The node to repair must not have slots.")
Expand All @@ -411,7 +401,7 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT]
val canonicalNode = canonicalizeWithoutSlots(node)

if (canonicalNode != node) {
nodeToClassOrNull(canonicalNode) match {
nodeToRefOrInvalid(canonicalNode) match {
case EClassRef.Invalid =>
// Eliminate the old node from the e-class and add the canonicalized node.
removeNodeFromClass(ref, node)
Expand Down Expand Up @@ -448,7 +438,7 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT]
// 3. The canonicalized node is different from the original node, and the canonicalized node is not in the
// hash-cons map. In this case, we add the canonicalized node to the hash-cons and queue its arguments
// for parent set repair.
val ref = nodeToClassOrNull(node)
val ref = nodeToRefOrInvalid(node)
if (Debug.isEnabled) {
assert(ref != EClassRef.Invalid, "The node to repair must be in the hash-cons.")
}
Expand Down Expand Up @@ -488,7 +478,7 @@ private[hashCons] abstract class AbstractMutableHashConsEGraph[NodeT]
}

if (canonicalNode.shape != node) {
nodeToClassOrNull(canonicalNode.shape) match {
nodeToRefOrInvalid(canonicalNode.shape) match {
case EClassRef.Invalid =>
// Eliminate the old node from the e-class and add the canonicalized node.
removeNodeFromClass(ref, node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ private[hashCons] trait ReadOnlyHashConsEGraph[NodeT] extends EGraph[NodeT] {
*/
protected val unionFind: UnionFind

/**
* Retrieves the e-class reference for a given e-node, or returns an invalid reference if the e-node is not found.
* This method does not canonicalize the e-node before looking it up; it assumes the caller has already done so
* and simply performs a hash cons lookup.
* @param node The e-node to look up.
* @return The e-class reference corresponding to the e-node, or an invalid reference if not found.
*/
def nodeToRefOrInvalid(node: ENode[NodeT]): EClassRef

/**
* Retrieves the e-class reference for a given e-node, or returns a default value if the e-node is not found.
* This method does not canonicalize the e-node before looking it up; it assumes the caller has already done so
Expand All @@ -38,7 +47,10 @@ private[hashCons] trait ReadOnlyHashConsEGraph[NodeT] extends EGraph[NodeT] {
* @param default A default e-class reference to return if the e-node is not found.
* @return The e-class reference corresponding to the e-node, or the default value if not found.
*/
def nodeToRefOrElse(node: ENode[NodeT], default: => EClassRef): EClassRef
final def nodeToRefOrElse(node: ENode[NodeT], default: => EClassRef): EClassRef = {
val ref = nodeToRefOrInvalid(node)
if (ref.isInvalid) default else ref
}

/**
* Retrieves the data associated with a given e-class. Assumes that the e-class reference is canonical.
Expand Down Expand Up @@ -165,7 +177,7 @@ private[hashCons] trait ReadOnlyHashConsEGraph[NodeT] extends EGraph[NodeT] {
assert(renamedShape.shape.isShape)
}

val ref = nodeToRefOrElse(renamedShape.shape, EClassRef.Invalid)
val ref = nodeToRefOrInvalid(renamedShape.shape)
if (ref.isInvalid) {
return null
}
Expand All @@ -185,7 +197,7 @@ private[hashCons] trait ReadOnlyHashConsEGraph[NodeT] extends EGraph[NodeT] {
assert(!node.hasSlots)
}

val ref = nodeToRefOrElse(node, EClassRef.Invalid)
val ref = nodeToRefOrInvalid(node)
if (ref.isInvalid) {
return null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ private[eqsat] final case class HashConsEGraph[NodeT] private[hashCons](protecte
unionFind.findOrNull(ref)
}

override def nodeToRefOrElse(node: ENode[NodeT], default: => EClassRef): EClassRef = {
hashCons.getOrElse(node, default)
override def nodeToRefOrInvalid(node: ENode[NodeT]): EClassRef = {
hashCons.getOrElse(node, EClassRef.Invalid)
}

override def dataForClass(ref: EClassRef): EClassData[NodeT] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ private final class HashConsEGraphBuilder[NodeT](protected override val unionFin

protected override def shapes: Iterable[ENode[NodeT]] = hashCons.keys

override def nodeToRefOrElse(node: ENode[NodeT], default: => EClassRef): EClassRef = {
hashCons.getOrElse(node, default)
override def nodeToRefOrInvalid(node: ENode[NodeT]): EClassRef = {
hashCons.getOrElse(node, EClassRef.Invalid)
}

override def dataForClass(ref: EClassRef): EClassData[NodeT] = {
Expand Down
Loading