Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions src/kirin/dialects/scf/unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from kirin.dialects import func
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.dialects.py.constant import Constant
from kirin.dialects.py.indexing import GetItem

from .stmts import For, Yield, IfElse

Expand Down Expand Up @@ -54,21 +55,35 @@ def insert_body(self, node: IfElse, body: ir.Region):

class ForLoop(RewriteRule):

def yield_item_results_const(self, node: For, hint: const.Value):
for item in hint.data:
item_stmt = Constant(item)
item_stmt.insert_before(node)
yield item_stmt.result

def yield_item_results_from_len(self, node: For, len_iterable: int):
for i in range(len_iterable):
(index_stmt := Constant(i)).insert_before(node)
(item_stmt := GetItem(node.iterable, index_stmt.result)).insert_before(node)
yield item_stmt.result

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, For):
return RewriteResult()

# TODO: support for PartialTuple and IList with known length
if not isinstance(hint := node.iterable.hints.get("const"), const.Value):
if isinstance(hint := node.iterable.hints.get("const"), const.Value):
item_results = self.yield_item_results_const(node, hint)
elif isinstance(hint, const.PartialTuple):
item_results = self.yield_item_results_from_len(node, len(hint.data))
else:
return RewriteResult()

loop_vars = node.initializers
for item in hint.data:

for item_result in item_results:
body = node.body.clone()
block = body.blocks[0]
item_stmt = Constant(item)
item_stmt.insert_before(node)
block.args[0].replace_by(item_stmt.result)
block.args[0].replace_by(item_result)
for var, input in zip(block.args[1:], loop_vars):
var.replace_by(input)

Expand Down
30 changes: 30 additions & 0 deletions test/dialects/scf/test_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,33 @@ def simple_loop(x):
assert isinstance(stmts.at(6), py.Add)
assert isinstance(stmts.at(7), func.Return)
assert simple_loop(1) == 4


def test_partial_tuple_loop_unroll():
@structural_no_opt
def simple_loop(a: int, b: int, c: int):
x = 0
for i in (a, b, c):
x = x + i
return x

fold = Fold(structural_no_opt)
fold(simple_loop)
Walk(scf.unroll.ForLoop()).rewrite(simple_loop.code)
fold(simple_loop)

# after fold the `getitem` should be eliminated as well
# leaving just the block arguments being added directly
# to `x`
assert len(simple_loop.callable_region.blocks) == 1
block = simple_loop.callable_region.blocks[0]
args = block.args
stmts = block.stmts
assert isinstance(stmts.at(0), py.Constant)
assert isinstance(stmt := stmts.at(1), py.Add)
assert stmt.rhs is args[1]
assert isinstance(stmt := stmts.at(2), py.Add)
assert stmt.rhs is args[2]
assert isinstance(stmt := stmts.at(3), py.Add)
assert stmt.rhs is args[3]
assert isinstance(stmts.at(4), func.Return)
Loading