From 8675ac2cf0453d5a38df4c93f46382f4d8106793 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 8 Jan 2026 09:36:56 -0500 Subject: [PATCH 1/2] updating unroll to handle PartialTuple --- src/kirin/dialects/scf/unroll.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/kirin/dialects/scf/unroll.py b/src/kirin/dialects/scf/unroll.py index 09e2ef794..0d1ac33c7 100644 --- a/src/kirin/dialects/scf/unroll.py +++ b/src/kirin/dialects/scf/unroll.py @@ -3,6 +3,7 @@ from kirin.dialects import func from kirin.rewrite.abc import RewriteRule, RewriteResult from kirin.dialects.py.constant import Constant +from kirin.dialects.py.indexing import GetItem from .stmts import For, Yield, IfElse @@ -54,21 +55,35 @@ def insert_body(self, node: IfElse, body: ir.Region): class ForLoop(RewriteRule): + def yield_item_results_const(self, node: For, hint: const.Value): + for item in hint.data: + item_stmt = Constant(item) + item_stmt.insert_before(node) + yield item_stmt.result + + def yield_item_results_from_len(self, node: For, len_iterable: int): + for i in range(len_iterable): + (index_stmt := Constant(i)).insert_before(node) + (item_stmt := GetItem(node.iterable, index_stmt.result)).insert_before(node) + yield item_stmt.result + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if not isinstance(node, For): return RewriteResult() - # TODO: support for PartialTuple and IList with known length - if not isinstance(hint := node.iterable.hints.get("const"), const.Value): + if isinstance(hint := node.iterable.hints.get("const"), const.Value): + item_results = self.yield_item_results_const(node, hint) + elif isinstance(hint, const.PartialTuple): + item_results = self.yield_item_results_from_len(node, len(hint.data)) + else: return RewriteResult() loop_vars = node.initializers - for item in hint.data: + + for item_result in item_results: body = node.body.clone() block = body.blocks[0] - item_stmt = Constant(item) - item_stmt.insert_before(node) - block.args[0].replace_by(item_stmt.result) + block.args[0].replace_by(item_result) for var, input in zip(block.args[1:], loop_vars): var.replace_by(input) From 70ea1b473d2552a722aeb4dc769d1ed5e5c9f30e Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 8 Jan 2026 09:50:01 -0500 Subject: [PATCH 2/2] adding test for partial tuple branch of rewrite --- test/dialects/scf/test_unroll.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/dialects/scf/test_unroll.py b/test/dialects/scf/test_unroll.py index 876bd0673..c04e610f6 100644 --- a/test/dialects/scf/test_unroll.py +++ b/test/dialects/scf/test_unroll.py @@ -25,3 +25,33 @@ def simple_loop(x): assert isinstance(stmts.at(6), py.Add) assert isinstance(stmts.at(7), func.Return) assert simple_loop(1) == 4 + + +def test_partial_tuple_loop_unroll(): + @structural_no_opt + def simple_loop(a: int, b: int, c: int): + x = 0 + for i in (a, b, c): + x = x + i + return x + + fold = Fold(structural_no_opt) + fold(simple_loop) + Walk(scf.unroll.ForLoop()).rewrite(simple_loop.code) + fold(simple_loop) + + # after fold the `getitem` should be eliminated as well + # leaving just the block arguments being added directly + # to `x` + assert len(simple_loop.callable_region.blocks) == 1 + block = simple_loop.callable_region.blocks[0] + args = block.args + stmts = block.stmts + assert isinstance(stmts.at(0), py.Constant) + assert isinstance(stmt := stmts.at(1), py.Add) + assert stmt.rhs is args[1] + assert isinstance(stmt := stmts.at(2), py.Add) + assert stmt.rhs is args[2] + assert isinstance(stmt := stmts.at(3), py.Add) + assert stmt.rhs is args[3] + assert isinstance(stmts.at(4), func.Return)