Skip to content

Commit f46b1e1

Browse files
authored
fix[next][dace]: Remove isolated access nodes for unused args in let-lambda (#2422)
This PR handles `SymRefs` to symbols which are not used, and which are passed as arguments to let-lambdas. The cleanup stage was removed in #2418, resulting in isolated access nodes. Test coverage added.
1 parent 116024d commit f46b1e1

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,20 @@ def visit_Lambda(
12351235

12361236
ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet)
12371237

1238+
# We can now safely remove all remaining arguments, because unused.
1239+
# It can happen that a let-lambda argument is not used, in which case it
1240+
# should be `None` at this stage, but the symbol passed as argument is captured
1241+
# by the let-lambda expression and thus can be accessed directly inside the
1242+
# inner scope. The domain in the annex of this symbol is not `NEVER`, and
1243+
# an access node is created, so we find an isolated access node here.
1244+
if unused_access_nodes := [
1245+
arg_node.dc_node
1246+
for arg_node in lambda_arg_nodes.values()
1247+
if not (arg_node is None or arg_node.dc_node.desc(ctx.sdfg).transient)
1248+
]:
1249+
assert all(ctx.state.degree(access_node) == 0 for access_node in unused_access_nodes)
1250+
ctx.state.remove_nodes_from(unused_access_nodes)
1251+
12381252
def construct_output_for_nested_sdfg(
12391253
inner_data: gtir_to_sdfg_types.FieldopData,
12401254
) -> gtir_to_sdfg_types.FieldopData:

src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,8 +618,8 @@ def translate_tuple_get(
618618
index = int(node.args[0].value)
619619

620620
data_nodes = sdfg_builder.visit(node.args[1], ctx=ctx)
621-
if isinstance(data_nodes, gtir_to_sdfg_types.FieldopData):
622-
raise ValueError(f"Invalid tuple expression {node}")
621+
if not isinstance(data_nodes, tuple):
622+
raise ValueError(f"Invalid tuple expression {node}.")
623623
# Now we remove the tuple fields that are not used, to avoid an SDFG validation
624624
# error because of isolated access nodes.
625625
unused_data_nodes = gtx_utils.flatten_nested_tuple(
@@ -724,6 +724,11 @@ def translate_symbol_ref(
724724
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
725725
assert isinstance(node, gtir.SymRef)
726726

727+
# If the symbol is not used, the domain is set to 'NEVER' in node annex.
728+
node_domain = getattr(node.annex, "domain", infer_domain.DomainAccessDescriptor.UNKNOWN)
729+
if node_domain == infer_domain.DomainAccessDescriptor.NEVER:
730+
return None
731+
727732
symbol_name = str(node.id)
728733
# we retrieve the type of the symbol in the GT4Py prgram
729734
gt_symbol_type = sdfg_builder.get_symbol_type(symbol_name)

tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,38 @@ def test_gtir_let_lambda():
16471647
assert np.allclose(b, ref)
16481648

16491649

1650+
def test_gtir_let_lambda_unused_arg():
1651+
testee = gtir.Program(
1652+
id="let_lambda_unused_arg",
1653+
function_definitions=[],
1654+
params=[
1655+
gtir.Sym(id="x", type=IFTYPE),
1656+
gtir.Sym(id="y", type=IFTYPE),
1657+
gtir.Sym(id="z", type=IFTYPE),
1658+
],
1659+
declarations=[],
1660+
body=[
1661+
gtir.SetAt(
1662+
# Arg 'xᐞ1' is used inside the let-lambda, 'yᐞ1' is not.
1663+
expr=im.let(("xᐞ1", im.op_as_fieldop("multiplies")("x", 3.0)), ("yᐞ1", "y"))(
1664+
im.op_as_fieldop("multiplies")("xᐞ1", 2.0)
1665+
),
1666+
domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "z", [IDim]),
1667+
target=gtir.SymRef(id="z"),
1668+
)
1669+
],
1670+
)
1671+
1672+
a = np.random.rand(N)
1673+
b = np.random.rand(N)
1674+
c = np.random.rand(N)
1675+
1676+
sdfg = build_dace_sdfg(testee, {})
1677+
1678+
sdfg(a, b, c, **FSYMBOLS)
1679+
assert np.allclose(c, a * 6.0)
1680+
1681+
16501682
def test_gtir_let_lambda_scalar_expression():
16511683
domain_inner = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, "size_inner")})
16521684
domain_outer = im.get_field_domain(

0 commit comments

Comments
 (0)