From 56b6b2bac28fb576fa2f19defe3a3762872b873d Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Thu, 4 Dec 2025 08:55:51 +0000
Subject: [PATCH 01/17] wip

Signed-off-by: Yaoyao Ding
---
 .vscode/settings.json                   |   3 +-
 python/tilus/ir/layout/shared_layout.py | 173 +++++++++++++-----------
 2 files changed, 99 insertions(+), 77 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index c41362fd..870ed23e 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,5 +7,6 @@
         "tests"
     ],
     "python.testing.unittestEnabled": false,
-    "python.testing.pytestEnabled": true
+    "python.testing.pytestEnabled": true,
+    "python.analysis.supportDocstringTemplate": true
 }
diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py
index 3636c59b..0dbc496e 100644
--- a/python/tilus/ir/layout/shared_layout.py
+++ b/python/tilus/ir/layout/shared_layout.py
@@ -15,40 +15,71 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Sequence
+from typing import Callable, Dict, List, Sequence, Optional
 
 from hidet.ir.expr import Expr, Var, as_expr
+from hidet.ir.utils.index_transform import index_deserialize
 
 from tilus.extensions.hidet.ir.expr import index_vars
 from tilus.ir.node import IRNode
+from tilus.ir.layout.ops.utils import get_mode_groups
 
 
+@dataclass(frozen=True, eq=True)
+class Swizzle:
+    """
+    A swizzle function.
+
+    0xxxxYYYxxZZZxxxx
+    z_mask = ((1 << self.bits) - 1) << self.base                   # the ZZZ bits that get flipped
+    y_mask = ((1 << self.bits) - 1) << (self.base + self.shift)    # the YYY bits used as the XOR source
+    return offset ^ ((offset & y_mask) >> self.shift)
+    """
+    base: int
+    bits: int
+    shift: int
+
+    def __call__(self, index: Expr) -> Expr:
+        # we use a primitive function here
+        # todo: use general computation after refactor to cute-like shared layout
+        from tilus.extensions.hidet.ir.primitives.swizzle import swizzle
+
+        if self.bits == 0:
+            return index
+        return swizzle(index, self.base, self.bits, self.shift)
+
 @dataclass(frozen=True, eq=False)
 class SharedLayout(IRNode):
     """The layout for shared tensor.
 
+    We use three components to describe a shared tensor layout: the shape, the mode shape, and the mode strides.
+
+    The mode shape and mode strides are used to describe how to split each dimension into multiple sub-dimensions (modes),
+    and the strides of each mode.
+
+    For example, consider a shape of (64, 32), we can split the first dimension into two sub-dimensions (modes) of size 8 and 8,
+    and the second dimension into two sub-dimensions (modes) of size 16 and 2. The mode shape would be (8, 8, 16, 2). We can
+    have strides for each mode, for example, (256, 2, 16, 1). Then given the indices (i, j), we can compute the indices in the
+    sub-dimensions (i1, i2, j1, j2) where i1 = i // 8, i2 = i % 8, j1 = j // 2, j2 = j % 2. The offset can be computed as:
+    offset = i1 * 256 + i2 * 2 + j1 * 16 + j2 * 1. To get the final offset in the shared tensor, we can use the formula:
+    (i, j) => ((i // 8) * 256) + ((i % 8) * 2) + ((j // 2) * 16) + ((j % 2) * 1).
+
     Attributes
     ----------
     shape: tuple[int, ...]
         The shape of the shared tensor. Each dimension is a constant integer.
-    size: int
-        The storage size of the shared tensor, in number of elements. If the layout is a `compact` layout, size
-        should be equal to the product of the shape dimensions. Otherwise, it can be either larger (in case of padding)
-        or smaller (in case of sharing data for different elements) than the product of the shape dimensions. The
-        size must be a constant integer.
- axes: tuple[Var, ...] - The axes of the shared tensor. Each axis is a variable that represents the index of the corresponding dimension. - It should have the same length as the shape. - offset: Expr - The offset expression of the shared tensor based on the axes. It is an expression that computes the offset - of the shared tensor based on the axes. Only the axes and variables that are invariant in the lifetime of the - given corresponding shared tensor with this layout can be used in the expression. + mode_shape: tuple[int, ...] + We can split each dimension into multiple sub-dimensions (modes). + mode_strides: tuple[int, ...] + The strides of each mode. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. """ shape: tuple[int, ...] - size: int - axes: tuple[Var, ...] - offset: Expr + mode_shape: tuple[int, ...] + mode_strides: tuple[int, ...] + swizzle: Optional[Swizzle] def __call__(self, *indices: Expr) -> Expr: """Compute the offset on given indices. @@ -65,82 +96,72 @@ def __call__(self, *indices: Expr) -> Expr: ret: Expr The computed offset of the shared tensor element at the given indices. """ - assert len(indices) == len(self.axes) - from hidet.ir.tools import rewrite - - return rewrite(self.offset, rewrite_map={axis: index for axis, index in zip(self.axes, indices)}) + # get the stride-based index + group_modes = get_mode_groups(self.shape, self.mode_shape) + mode_indices: list[Expr] = [] + for index, modes in zip(indices, group_modes): + mode_indices.extend(index_deserialize(index, shape=[self.mode_shape[m] for m in modes])) + total_index = sum(index * stride for index, stride in zip(mode_indices, self.mode_strides)) + + # apply swizzle if exists + if self.swizzle is not None: + total_index = self.swizzle(total_index) + + return total_index @staticmethod - def create(shape: Sequence[int], size: int, f_offset: Callable[[Sequence[Var]], Expr | int]) -> SharedLayout: - """Create a shared layout. - - This method creates a shared layout with the given shape, size, and a function to compute the offset based on - the axes. The shape must be a sequence of constant integers, and the size must be a constant integer that is - larger than the maximum possible offset computed by the `f_offset` function. - + def create(shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle]) -> SharedLayout: + """ + Create a SharedLayout from shape, mode_shape, and mode_strides. + Parameters ---------- shape: Sequence[int] - The shape of the shared tensor. Each dimension is a constant integer. - size: int - The storage size of the shared tensor, in number of elements. - f_offset: Callable[[Sequence[Var]], Expr] - The function that computes the offset of the shared tensor based on the axes. It takes a sequence of - axes (variables) and returns an expression that computes the offset. The function must ensure that the - size is larger than the maximum possible offset computed by this function. + The shape of the shared tensor. + mode_shape: Sequence[int] + The mode shape of the shared tensor. + mode_strides: Sequence[int] + The mode strides of the shared tensor. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. Returns ------- ret: SharedLayout - A shared layout with the specified shape, size, axes, and offset. + The created SharedLayout. 
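As a concrete check of the mode arithmetic described in the class docstring above, here is a minimal standalone sketch using the (64, 32) example; plain ints stand in for the symbolic hidet expressions that the class actually manipulates:

def offset_64x32(i: int, j: int) -> int:
    # mode shape (8, 8, 16, 2) with mode strides (256, 2, 16, 1)
    i1, i2 = i // 8, i % 8   # first dimension split into modes of extent 8 and 8
    j1, j2 = j // 2, j % 2   # second dimension split into modes of extent 16 and 2
    return i1 * 256 + i2 * 2 + j1 * 16 + j2 * 1

assert offset_64x32(0, 0) == 0
assert offset_64x32(1, 0) == 2     # stepping one row moves by the stride of the i2 mode
assert offset_64x32(9, 3) == 275   # 1 * 256 + 1 * 2 + 1 * 16 + 1 * 1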
""" - axes: List[Var] = index_vars(num_vars=len(shape)) - return SharedLayout(shape=tuple(shape), size=size, axes=tuple(axes), offset=as_expr(f_offset(axes))) + if any(s < 1 for s in shape): + raise ValueError("All dimensions in shape must be positive integers.") + if len(mode_shape) != len(mode_strides): + raise ValueError("mode_shape and mode_strides must have the same length.") + return SharedLayout(shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle) def slice(self, offsets: Sequence[Expr], slice_dims: Sequence[int], slice_shape: Sequence[int]) -> SharedLayout: - assert len(set(slice_dims)) == len(slice_dims), "slice_dims must be unique" - assert len(slice_shape) == len(slice_dims), "slice_dims and slice_shape must have the same length" - assert len(slice_dims) <= len(self.shape), "slice_dims must be less than or equal to the number of dimensions" - - def f_offset(axes: Sequence[Var]) -> Expr: - indices: List[Expr] = list(offsets) - for dim, axis in zip(slice_dims, axes): - indices[dim] = indices[dim] + axis - return self(*indices) - self(*offsets) - - return SharedLayout.create(shape=slice_shape, size=self.size, f_offset=f_offset) + raise RuntimeError("No slice anymore.") def simplify(self) -> SharedLayout: - from tilus.extensions.hidet.ir.tools import simplify_expr - from tilus.extensions.hidet.transforms.rule_based_simplifier import BoundInfo, RuleBasedSimplifier - - var2bound: Dict[Var, BoundInfo] = { - axis: BoundInfo(min_value=0, max_value=extent - 1) for axis, extent in zip(self.axes, self.shape) - } - simplifier = RuleBasedSimplifier(var2bound=var2bound) - return SharedLayout( - shape=self.shape, size=self.size, axes=self.axes, offset=simplify_expr(simplifier(self.offset)) - ) + raise RuntimeError("No need to simplify anymore.") def swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: - ndims = len(self.shape) - assert 0 <= dim < ndims and 0 <= regards_dim < ndims and dim != regards_dim - - def get_xor_index(indices: Sequence[Expr]) -> Expr: - indices = list(indices) # copy - step = 2**log_step - regards_index = indices[regards_dim] // step - regards_extent = self.shape[regards_dim] // step - if regards_extent > self.shape[dim]: - regards_index = regards_index % self.shape[dim] - return regards_index - - def f_offset(axes: Sequence[Var]) -> Expr: - swizzled_indices: List[Expr] = [axis for axis in axes] - swizzled_indices[dim] = swizzled_indices[dim] ^ get_xor_index(axes) - return self(*swizzled_indices) - - return SharedLayout.create(shape=self.shape, size=self.size, f_offset=f_offset) + raise RuntimeError("Update swizzle.") + # ndims = len(self.shape) + # assert 0 <= dim < ndims and 0 <= regards_dim < ndims and dim != regards_dim + + # def get_xor_index(indices: Sequence[Expr]) -> Expr: + # indices = list(indices) # copy + # step = 2**log_step + # regards_index = indices[regards_dim] // step + # regards_extent = self.shape[regards_dim] // step + # if regards_extent > self.shape[dim]: + # regards_index = regards_index % self.shape[dim] + # return regards_index + + # def f_offset(axes: Sequence[Var]) -> Expr: + # swizzled_indices: List[Expr] = [axis for axis in axes] + # swizzled_indices[dim] = swizzled_indices[dim] ^ get_xor_index(axes) + # return self(*swizzled_indices) + + # return SharedLayout.create(shape=self.shape, size=self.size, f_offset=f_offset) def prepend_dim(self, extent: int) -> SharedLayout: def f_offset(axes: Sequence[Var]) -> Expr: From 26304dd3ce53f10f3c3a914067ca9409298cb599 Mon Sep 17 
00:00:00 2001 From: Yaoyao Ding Date: Thu, 4 Dec 2025 20:57:55 +0000 Subject: [PATCH 02/17] wip Signed-off-by: Yaoyao Ding --- python/tilus/ir/layout/ops/shared_ops.py | 16 +++++++ python/tilus/ir/layout/shared_layout.py | 59 ++++++++++++++---------- python/tilus/lang/modules/cuda.py | 6 +-- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index 31bb7bc8..5d56f102 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -133,6 +133,22 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: axes = tuple(layout.axes[d] for d in dims) return SharedLayout(shape=shape, size=layout.size, axes=axes, offset=layout.offset) +def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: + shape = [] + cur_dim = 0 + for i in range(len(self.shape) + len(dims)): + if i in dims: + shape.append(1) + else: + shape.append(self.shape[cur_dim]) + cur_dim += 1 + + def f_offset(axes: Sequence[Var]) -> Expr: + base_axes = [axis for i, axis in enumerate(axes) if i not in dims] + return self(*base_axes) + + return SharedLayout.create(shape=shape, size=self.size, f_offset=f_offset) + def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str: """ diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 0dbc496e..352779e7 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -101,7 +101,7 @@ def __call__(self, *indices: Expr) -> Expr: mode_indices: list[Expr] = [] for index, modes in zip(indices, group_modes): mode_indices.extend(index_deserialize(index, shape=[self.mode_shape[m] for m in modes])) - total_index = sum(index * stride for index, stride in zip(mode_indices, self.mode_strides)) + total_index: Expr = as_expr(sum(index * stride for index, stride in zip(mode_indices, self.mode_strides))) # apply swizzle if exists if self.swizzle is not None: @@ -136,17 +136,31 @@ def create(shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequen raise ValueError("mode_shape and mode_strides must have the same length.") return SharedLayout(shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle) + @property + def size(self) -> int: + """Get the total size of the shared layout. + + It is the minimum number of elements required to store the tensor in shared memory. + + Returns + ------- + ret: int + The total size of the shared layout. 
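A quick numeric check of the size defined here, using the (8, 8, 16, 2) / (256, 2, 16, 1) layout from the class docstring (a standalone sketch; the property itself performs the same arithmetic on the stored tuples):

mode_shape = (8, 8, 16, 2)
mode_strides = (256, 2, 16, 1)
# the largest reachable offset is taken at the last element of every mode
max_index = sum((extent - 1) * stride for extent, stride in zip(mode_shape, mode_strides))
assert max_index + 1 == 2048 == 64 * 32   # this layout is compact, so size == prod(shape)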
+ """ + indices = [extent - 1 for extent in self.mode_shape] + max_index = sum(a * b for a, b in zip(indices, self.mode_strides)) + return max_index + 1 + def slice(self, offsets: Sequence[Expr], slice_dims: Sequence[int], slice_shape: Sequence[int]) -> SharedLayout: raise RuntimeError("No slice anymore.") def simplify(self) -> SharedLayout: raise RuntimeError("No need to simplify anymore.") - def swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: + def with_swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: raise RuntimeError("Update swizzle.") # ndims = len(self.shape) # assert 0 <= dim < ndims and 0 <= regards_dim < ndims and dim != regards_dim - # def get_xor_index(indices: Sequence[Expr]) -> Expr: # indices = list(indices) # copy # step = 2**log_step @@ -155,20 +169,27 @@ def swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: # if regards_extent > self.shape[dim]: # regards_index = regards_index % self.shape[dim] # return regards_index - # def f_offset(axes: Sequence[Var]) -> Expr: # swizzled_indices: List[Expr] = [axis for axis in axes] # swizzled_indices[dim] = swizzled_indices[dim] ^ get_xor_index(axes) # return self(*swizzled_indices) - # return SharedLayout.create(shape=self.shape, size=self.size, f_offset=f_offset) def prepend_dim(self, extent: int) -> SharedLayout: - def f_offset(axes: Sequence[Var]) -> Expr: - tile_offset = axes[0] * self.size - return tile_offset + self(*axes[1:]) - - return SharedLayout.create(shape=(extent,) + self.shape, size=extent * self.size, f_offset=f_offset) + shape = (extent,) + self.shape + if extent > 1: + mode_shape = (extent,) + self.mode_shape + mode_strides = (self.size,) + self.mode_strides + else: + mode_shape = self.mode_shape + mode_strides = self.mode_strides + + return SharedLayout.create( + shape=shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + swizzle=self.swizzle, + ) def transpose(self) -> SharedLayout: assert len(self.shape) == 2 @@ -180,20 +201,10 @@ def permute(self, dims: Sequence[int]) -> SharedLayout: return shared_permute(self, dims) def unsqueeze(self, dims: Sequence[int]) -> SharedLayout: - shape = [] - cur_dim = 0 - for i in range(len(self.shape) + len(dims)): - if i in dims: - shape.append(1) - else: - shape.append(self.shape[cur_dim]) - cur_dim += 1 - - def f_offset(axes: Sequence[Var]) -> Expr: - base_axes = [axis for i, axis in enumerate(axes) if i not in dims] - return self(*base_axes) - - return SharedLayout.create(shape=shape, size=self.size, f_offset=f_offset) + from tilus.ir.layout.ops.shared_ops import shared_unsqueeze + + return shared_unsqueeze(self, dims) + def visualize(self, tablefmt: str = "simple_grid") -> str: from tilus.ir.layout.ops.shared_ops import visualize_layout diff --git a/python/tilus/lang/modules/cuda.py b/python/tilus/lang/modules/cuda.py index 2d2e13bf..f95e2215 100644 --- a/python/tilus/lang/modules/cuda.py +++ b/python/tilus/lang/modules/cuda.py @@ -309,7 +309,7 @@ def _swizzled_shared_layout(dtype: DataType, shape: tuple[int, ...]) -> SharedLa 6 7 4 5 2 3 0 1 7 6 5 4 3 2 1 0 """ - core = shared_row_major(rows, columns).swizzle(dim=1, regards_dim=0, log_step=0) + core = shared_row_major(rows, columns).with_swizzle(dim=1, regards_dim=0, log_step=0) elif columns % 4 == 0: """ 0 1 2 3 @@ -321,7 +321,7 @@ def _swizzled_shared_layout(dtype: DataType, shape: tuple[int, ...]) -> SharedLa 3 2 1 0 7 6 5 4 """ - core = shared_row_major(rows, 4).swizzle(dim=1, regards_dim=0, log_step=1) + core = shared_row_major(rows, 
4).with_swizzle(dim=1, regards_dim=0, log_step=1) elif columns % 2 == 0: """ 0 1 @@ -333,7 +333,7 @@ def _swizzled_shared_layout(dtype: DataType, shape: tuple[int, ...]) -> SharedLa 5 4 7 6 """ - core = shared_row_major(rows, 2).swizzle(dim=1, regards_dim=0, log_step=2) + core = shared_row_major(rows, 2).with_swizzle(dim=1, regards_dim=0, log_step=2) else: """ 0 From 613b9bb6d720b9cc4580da2c83fecd19639a81e3 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 5 Dec 2025 08:20:25 +0000 Subject: [PATCH 03/17] wip Signed-off-by: Yaoyao Ding --- examples/quantization/matmul_a16wx.py | 37 ++- python/tilus/ir/layout/cuda/tcgen05/smem.py | 17 +- python/tilus/ir/layout/inference/inference.py | 6 + .../inference/inference_rules/load_shared.py | 4 +- .../inference/inference_rules/store_shared.py | 4 +- python/tilus/ir/layout/ops/__init__.py | 7 +- python/tilus/ir/layout/ops/shared_ops.py | 309 ++++++++++++++---- python/tilus/ir/layout/ops/utils.py | 6 + python/tilus/ir/layout/shared_layout.py | 116 +++++-- python/tilus/ir/layout/utils/cute.py | 38 ++- python/tilus/ir/tools/printer.py | 7 +- python/tilus/lang/modules/cuda.py | 185 +---------- 12 files changed, 417 insertions(+), 319 deletions(-) diff --git a/examples/quantization/matmul_a16wx.py b/examples/quantization/matmul_a16wx.py index cb1deb77..b08ac3e2 100644 --- a/examples/quantization/matmul_a16wx.py +++ b/examples/quantization/matmul_a16wx.py @@ -197,8 +197,8 @@ def __init__( self.block_k = self.atomic_mma.k * wrk self.num_warps = wsm * wsn - k_tiles = wrk // tk - n_tiles = wsn * wrn // tn + self.k_tiles = wrk // tk + self.n_tiles = wsn * wrn // tn # we make sure that each weight_tile will be loaded by one warp assert wrk * self.atomic_mma.k % weight_tile[0] == 0 @@ -217,16 +217,16 @@ def __init__( ) self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True) - self.layout_sa = self.cuda.swizzled_shared_layout( - self.a_dtype, shape=[num_stages, self.block_m, self.block_k] - ) - self.layout_sb = self.cuda.shared_layout( - shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes] - ) - self.layout_sc = self.cuda.swizzled_shared_layout( - self.a_dtype, shape=[self.block_m, self.block_n] - ) - self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n]) + # self.layout_sa = self.cuda.swizzled_shared_layout( + # self.a_dtype, shape=[num_stages, self.block_m, self.block_k] + # ) + # self.layout_sb = self.cuda.shared_layout( + # shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes] + # ) + # self.layout_sc = self.cuda.swizzled_shared_layout( + # self.a_dtype, shape=[self.block_m, self.block_n] + # ) + # self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n]) def __call__( self, @@ -275,7 +275,10 @@ def __call__( sa = self.shared_tensor( dtype=self.a_dtype, shape=[self.num_stages, block_m, block_k] ) - sb = self.shared_tensor(dtype=uint8, shape=self.layout_sb.shape) + sb = self.shared_tensor( + dtype=uint8, + shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes], + ) ss = self.shared_tensor(dtype=self.a_dtype, shape=[self.num_stages, 1, block_n]) acc = self.register_tensor( dtype=float32, @@ -395,10 +398,10 @@ def __call__( ) # annotate layouts - self.annotate_layout(sc, layout=self.layout_sc) - self.annotate_layout(sa, layout=self.layout_sa) - self.annotate_layout(sb, layout=self.layout_sb) - self.annotate_layout(ss, layout=self.layout_ss) + # self.annotate_layout(sc, layout=self.layout_sc) + # self.annotate_layout(sa, layout=self.layout_sa) + # self.annotate_layout(sb, 
layout=self.layout_sb) + # self.annotate_layout(ss, layout=self.layout_ss) self.annotate_layout(acc, layout=self.mma.lc) diff --git a/python/tilus/ir/layout/cuda/tcgen05/smem.py b/python/tilus/ir/layout/cuda/tcgen05/smem.py index c4f788c9..d903095d 100644 --- a/python/tilus/ir/layout/cuda/tcgen05/smem.py +++ b/python/tilus/ir/layout/cuda/tcgen05/smem.py @@ -18,9 +18,7 @@ from typing import Literal, Optional, Sequence, cast import numpy as np -from hidet.ir.expr import Expr, Var from hidet.ir.type import DataType -from hidet.utils.py import prod from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode from tilus.ir.layout.shared_layout import SharedLayout from tilus.ir.layout.utils.cute import CuteLayout, CuteSwizzle, IntTuple, SwizzledCuteLayout, cute_layout, tuple_product @@ -303,17 +301,14 @@ def get_shared_layout_from_canonical(canonical_layout: CanonicalSharedLayout) -> else: raise ValueError(f"Unsupported major_kind: {canonical_layout.major_kind}") - def f_offset(axes: Sequence[Var]) -> Expr | int: - nbytes = 16 // canonical_layout.T - swizzle = CuteSwizzle(bbits=bbits, mbase=mbase - floor_log2(nbytes), sshift=sshift) - return swizzle(layout(*axes)) + nbytes = 16 // canonical_layout.T + swizzle = CuteSwizzle(bbits=bbits, mbase=mbase - floor_log2(nbytes), sshift=sshift) + swizzled_cute_layout = SwizzledCuteLayout(layout, swizzle) - if not isinstance(layout.shape, Sequence): - smem_shape = [int(layout.shape)] - else: - smem_shape = [int(tuple_product(item)) for item in layout.shape] + assert isinstance(layout.shape, Sequence) + shape = [int(tuple_product(item)) for item in layout.shape] - return SharedLayout.create(shape=smem_shape, size=prod(smem_shape), f_offset=f_offset) + return swizzled_cute_layout.as_shared_layout(shape) def generate_canonical_layout( diff --git a/python/tilus/ir/layout/inference/inference.py b/python/tilus/ir/layout/inference/inference.py index fc9bb454..6f63a355 100644 --- a/python/tilus/ir/layout/inference/inference.py +++ b/python/tilus/ir/layout/inference/inference.py @@ -278,6 +278,12 @@ def pair_sort_key(pair: tuple[Instruction, Type[LayoutInferenceRule]]) -> tuple[ SharedTensor | RegisterTensor | TMemoryTensor, SharedTensor | RegisterTensor | TMemoryTensor ] = {} for tensor, layout in mapping.items(): + assert isinstance(tensor, (RegisterTensor, SharedTensor, TMemoryTensor)), ( + f"Invalid tensor type {type(tensor)} for rule {rule.__name__} " + ) + assert isinstance(layout, (RegisterLayout, SharedLayout, TMemoryLayout)), ( + f"Invalid layout type {type(layout)} for rule {rule.__name__} " + ) assert same_list(tensor.shape, layout.shape), ( f"Layout shape does not match tensor shape: {tensor.shape} vs {layout.shape} for rule {rule.__name__} " ) diff --git a/python/tilus/ir/layout/inference/inference_rules/load_shared.py b/python/tilus/ir/layout/inference/inference_rules/load_shared.py index e9c36fed..bb57444a 100644 --- a/python/tilus/ir/layout/inference/inference_rules/load_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/load_shared.py @@ -19,6 +19,7 @@ from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixConfig from tilus.ir.layout import LayoutOperationError, ops from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule +from tilus.ir.layout.ops import shared_row_major_swizzle from tilus.utils import gcd @@ -45,9 +46,8 @@ def inference( continue # use swizzle layout since we are using ldmatrix instruction - from tilus.lang.modules.cuda import cuda - return {a: 
cuda.swizzled_shared_layout(dtype=a.dtype, shape=a.shape)} + return {a: shared_row_major_swizzle(dtype_nbytes=a.dtype.nbytes, shape=a.shape)} return {} diff --git a/python/tilus/ir/layout/inference/inference_rules/store_shared.py b/python/tilus/ir/layout/inference/inference_rules/store_shared.py index be450b65..5d6d61b7 100644 --- a/python/tilus/ir/layout/inference/inference_rules/store_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/store_shared.py @@ -18,6 +18,7 @@ from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixConfig from tilus.ir.layout import LayoutOperationError, ops from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule +from tilus.ir.layout.ops import shared_row_major_swizzle @register_rule(StoreSharedGenericInst) @@ -42,8 +43,7 @@ def inference( continue # use swizzle layout since we are using ldmatrix instruction - from tilus.lang.modules.cuda import cuda - return {a: cuda.swizzled_shared_layout(dtype=a.dtype, shape=a.shape)} + return {a: shared_row_major_swizzle(dtype_nbytes=a.dtype.nbytes, shape=a.shape)} return {} diff --git a/python/tilus/ir/layout/ops/__init__.py b/python/tilus/ir/layout/ops/__init__.py index 3bb326ba..db421638 100644 --- a/python/tilus/ir/layout/ops/__init__.py +++ b/python/tilus/ir/layout/ops/__init__.py @@ -38,12 +38,7 @@ squeeze, unsqueeze, ) -from .shared_ops import ( - shared_column_major, - shared_compose, - shared_permute, - shared_row_major, -) +from .shared_ops import shared_column_major, shared_compose, shared_permute, shared_row_major, shared_row_major_swizzle from .tmemory_ops import ( tmemory_row_major, tmemory_slice, diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index 5d56f102..87eaba87 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -17,42 +17,35 @@ from typing import List, Sequence import tabulate -from hidet.ir.dtypes import int32 -from hidet.ir.expr import Expr, Var -from hidet.utils import prod +from hidet.utils import gcd, prod -from tilus.extensions.hidet.ir.utils.index_transform import vector_mul -from tilus.ir.layout.ops.utils import LayoutOperationError -from tilus.ir.layout.shared_layout import SharedLayout +from tilus.extensions.hidet.ir.expr import index_vars +from tilus.ir.layout.ops.utils import LayoutOperationError, get_mode_groups +from tilus.ir.layout.shared_layout import SharedLayout, Swizzle, shared_layout from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate -def _generic_repeat(shape: List[int], ranks: List[int]) -> SharedLayout: - assert len(shape) == len(ranks) - assert len(ranks) == len(set(ranks)) and all(0 <= d < len(shape) for d in ranks) - strides: List[int] = [prod([s for j, s in enumerate(shape) if ranks[j] > ranks[i]]) for i in range(len(shape))] - - def f_offset(axes: Sequence[Var]) -> Expr: - return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero) - - return SharedLayout.create(shape=shape, size=prod(shape), f_offset=f_offset) - - -def _shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: - assert len(lhs.shape) == len(rhs.shape) - ndims = len(lhs.shape) - - def f_offset(axes: Sequence[Var]) -> Expr: - lhs_axes = [axes[i] // rhs.shape[i] for i in range(ndims)] - rhs_axes = [axes[i] % rhs.shape[i] for i in range(ndims)] - lhs_offset = lhs(*lhs_axes) - rhs_offset = rhs(*rhs_axes) - return lhs_offset * rhs.size + rhs_offset +def strides_from_ranks(shape: Sequence[int], ranks: 
Sequence[int]) -> list[int]: + """ + Compute the strides from the ranks of each dimension. - shape = vector_mul(lhs.shape, rhs.shape) - size = lhs.size * rhs.size + Parameters + ---------- + shape: Sequence[int] + The shape of the tensor. + ranks: Sequence[int] + The ranks of each dimension. The length of ranks must be equal to the length of shape + and all elements in ranks must be unique and in the range [0, len(shape)). - return SharedLayout.create(shape=shape, size=size, f_offset=f_offset) + Returns + ------- + ret: list[int] + The strides of each dimension. + """ + assert len(shape) == len(ranks) + assert len(ranks) == len(set(ranks)) and all(0 <= d < len(shape) for d in ranks) + strides: list[int] = [prod([s for j, s in enumerate(shape) if ranks[j] > ranks[i]]) for i in range(len(shape))] + return strides def shared_row_major(*shape: int) -> SharedLayout: @@ -68,7 +61,9 @@ def shared_row_major(*shape: int) -> SharedLayout: ret: SharedLayout A shared layout with the specified shape in row-major order. """ - return _generic_repeat(shape=list(shape), ranks=list(range(len(shape)))) + mode_shape = shape + mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(range(len(mode_shape)))) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) def shared_column_major(*shape: int) -> SharedLayout: @@ -84,10 +79,12 @@ def shared_column_major(*shape: int) -> SharedLayout: ret: SharedLayout A shared layout with the specified shape in column-major order. """ - return _generic_repeat(shape=list(shape), ranks=list(reversed(range(len(shape))))) + mode_shape = shape + mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(reversed(range(len(mode_shape))))) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) -def shared_compose(lhs: SharedLayout, rhs: SharedLayout, *others: SharedLayout) -> SharedLayout: +def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: """Compose multiple shared layouts together. Parameters @@ -96,18 +93,34 @@ def shared_compose(lhs: SharedLayout, rhs: SharedLayout, *others: SharedLayout) The first shared layout to compose. rhs: SharedLayout The second shared layout to compose. - others: Sequence[SharedLayout] - The additional shared layouts to compose with the first two. It can be empty. Returns ------- ret: SharedLayout The composed shared layout. 
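The intent of the composition is easiest to see with two small row-major layouts; the following is a hand-worked sketch, assuming the shared_row_major/shared_compose definitions in this patch. Composing a (2, 2) row-major layout of tiles with a (4, 8) row-major tile yields an (8, 16) layout made of four contiguous 4 x 8 tiles:

def composed_offset(i: int, j: int) -> int:
    tile = (i // 4) * 2 + (j // 8)   # tiles ordered by the outer row-major (2, 2) layout
    inner = (i % 4) * 8 + (j % 8)    # row-major offset inside one 4 x 8 tile
    return tile * 32 + inner         # 32 is the count_size() of the inner layout

# the same mapping in mode form: mode_shape (2, 4, 2, 8), mode_strides (64, 8, 32, 1)
assert composed_offset(5, 9) == (5 // 4) * 64 + (5 % 4) * 8 + (9 // 8) * 32 + (9 % 8) * 1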
""" - if len(others) == 0: - return _shared_compose(lhs, rhs) - else: - return shared_compose(_shared_compose(lhs, rhs), *others) + assert len(lhs.shape) == len(rhs.shape) + ndims = len(lhs.shape) + + # shape + shape = tuple(lhs.shape[i] * rhs.shape[i] for i in range(ndims)) + + # mode shape + lhs_mode_groups = get_mode_groups(lhs.shape, lhs.mode_shape) + rhs_mode_groups = get_mode_groups(rhs.shape, rhs.mode_shape) + mode_shape: list[int] = [] + for lhs_group, rhs_group in zip(lhs_mode_groups, rhs_mode_groups): + mode_shape.extend([lhs.mode_shape[i] for i in lhs_group]) + mode_shape.extend([rhs.mode_shape[i] for i in rhs_group]) + + # mode strides + mode_strides: list[int] = [] + rhs_size = rhs.count_size() + for lhs_group, rhs_group in zip(lhs_mode_groups, rhs_mode_groups): + mode_strides.extend([stride * rhs_size for stride in (lhs.mode_strides[i] for i in lhs_group)]) + mode_strides.extend([rhs.mode_strides[i] for i in rhs_group]) + + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: @@ -127,27 +140,209 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: ret: SharedLayout The permuted layout. """ - if set(dims) != set(range(len(layout.shape))): - raise LayoutOperationError("Dims must be a permutation of {}, got {}".format(range(len(layout.shape)), dims)) + assert len(dims) == len(layout.shape) and set(dims) == set(range(len(layout.shape))) + + # shape shape = tuple(layout.shape[d] for d in dims) - axes = tuple(layout.axes[d] for d in dims) - return SharedLayout(shape=shape, size=layout.size, axes=axes, offset=layout.offset) + + # mode shape and mode strides + layout_mode_groups = get_mode_groups(layout.shape, layout.mode_shape) + mode_shape: list[int] = [] + mode_strides: list[int] = [] + for d in dims: + mode_shape.extend([layout.mode_shape[i] for i in layout_mode_groups[d]]) + mode_strides.extend([layout.mode_strides[i] for i in layout_mode_groups[d]]) + + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=layout.swizzle) + def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: - shape = [] - cur_dim = 0 - for i in range(len(self.shape) + len(dims)): - if i in dims: - shape.append(1) - else: - shape.append(self.shape[cur_dim]) - cur_dim += 1 + """Unsqueeze the shared layout by adding new dimensions of size 1. + + Parameters + ---------- + layout: SharedLayout + The layout to unsqueeze. + dims: Sequence[int] + The dimensions to unsqueeze. Each dimension should be in the range [0, len(layout.shape)]. + + Returns + ------- + ret: SharedLayout + The unsqueezed layout. + """ + assert all(0 <= d <= len(layout.shape) for d in dims) and len(dims) == len(set(dims)) + shape: List[int] = list(layout.shape) + for d in sorted(dims): + shape.insert(d, 1) + return shared_layout( + shape=shape, + mode_shape=layout.mode_shape, + mode_strides=layout.mode_strides, + swizzle=layout.swizzle, + ) + + +def shared_row_major_swizzle(shape: Sequence[int], dtype_nbytes: int) -> SharedLayout: + """ + Generate a shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. + + Both m and n must be a multiple of 8. + + We will divide each row into bank groups, and bank group has 16 bytes (16 x uint8, 8 x fp16, or 4 x fp32, etc.). + They correspond to 4 banks in shared memory. 
For example, if we have m = n = 8 and dtype=fp16, we can represent + bank groups as + + 0 # bank group 0, banks from 0 to 3 + 1 # bank group 1, banks from 4 to 7 + 2 # ... + 3 + 4 + 5 + 6 + 7 # bank groups 7, banks from 28 to 31 + + Given m, and n, we need to find a proper way to organize the m x (n / 8) bank groups in shared memory, so that + 1) each row has different bank groups + 2) each column has different bank groups + + When we have m = 8 and n = 64, we have 8 x 8 bank groups. If we store the elements in row-major order, we will + have the bank groups as + + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + + If we use ldmatrix to load the above 8 x 64 shared memory, we will need 8 ldmatrix.v1 instructions. Each instruction + loads one column (8 x 8 elements, or 8 x 1 bank groups). Since each instruction will access the same bank group, + severe bank conflicts will occur. Thus, we need to change the layout of shared memory to avoid bank conflicts. + + Let layout(i, j) be the shared memory address of logical elements (each element has 16 bytes) when we use + a specific `layout`. For example, the row-major layout row-major(i, j) = i * n + j * 8 (we assume the dtype has 2 + bytes). If we use the swizzled layout swizzled(i, j) = row-major(i, j ^ i) = i * n + (j ^ i) * 8, we can have the + following bank groups in shared memory. + + 0 1 2 3 4 5 6 7 + 1 0 3 2 5 4 7 6 + 2 3 0 1 6 7 4 5 + 3 2 1 0 7 6 5 4 + 4 5 6 7 0 1 2 3 + 5 4 7 6 1 0 3 2 + 6 7 4 5 2 3 0 1 + 7 6 5 4 3 2 1 0 + + (reader may need some time to figure out the above layout...) + + This layout has two benefits: + 1) Each row has different bank groups. In above example, we have 32 banks per row. + 2) Each column has different bank groups. In above example, we have 32 banks per column. + + The benefit 1 makes sure that when we load data from global memory to shared memory, we can store efficiently. + The benefit 2 makes sure that when we load data from shared memory to register memory, we can load efficiently. + + We can always generate the swizzled layout for arbitrary m and n as long as they are multiple of 8. See the + implementation for more details. + + Parameters + ---------- + shape: Sequence[int] + The shape of the shared memory. The shape must have at least two dimensions. + + dtype_nbytes: int + The element data type size in bytes. - def f_offset(axes: Sequence[Var]) -> Expr: - base_axes = [axis for i, axis in enumerate(axes) if i not in dims] - return self(*base_axes) + Returns + ------- + shared_layout: SharedLayout + The shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. 
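The 8 x 8 table above can be reproduced in a few lines (a standalone sketch for the m = 8, n = 64, fp16 case, where each row holds eight 16-byte bank groups and the swizzle XORs the bank-group index with the row index):

rows, groups_per_row = 8, 8
for i in range(rows):
    print(" ".join(str(j ^ i) for j in range(groups_per_row)))
# prints the table above: every row and every column contains each bank group exactly once,
# so neither the global->shared store nor the shared->register load hits a bank-group conflict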
+ """ + if len(shape) < 2: + raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") + head, m, n = tuple(shape[:-2]), shape[-2], shape[-1] + + if m % 8 != 0 or n * dtype_nbytes % 16 != 0: + raise ValueError("m must be a multiple of 8, and n * dtype_nbytes must be a multiple of 16.") + + n_vector_size: int = gcd(n, 128 // dtype_nbytes) + n_num_vectors: int = n // n_vector_size + + mode_shape = head + (m, n_num_vectors, n_vector_size) + + # use the order of head, columns_vectors, rows, columns_vec_size to compute the strides + ranks = list(range(len(head))) + [len(head) + 1, len(head), len(head) + 2] + mode_strides = strides_from_ranks(shape=mode_shape, ranks=ranks) + + log2 = { + 1: 0, + 2: 1, + 4: 2, + 8: 3, + 16: 4, + } + + if n_vector_size * dtype_nbytes == 128: + """ + (each number represents a 16-byte group of elements) + 0 1 2 3 4 5 6 7 + 1 0 3 2 5 4 7 6 + 2 3 0 1 6 7 4 5 + 3 2 1 0 7 6 5 4 + 4 5 6 7 0 1 2 3 + 5 4 7 6 1 0 3 2 + 6 7 4 5 2 3 0 1 + 7 6 5 4 3 2 1 0 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=3, shift=3) + elif n_vector_size * dtype_nbytes == 64: + """ + 0 1 2 3 + 4 5 6 7 + 1 0 3 2 + 5 4 7 6 + 2 3 0 1 + 6 7 4 5 + 3 2 1 0 + 7 6 5 4 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=2, shift=3) + elif n_vector_size * dtype_nbytes == 32: + """ + 0 1 + 2 3 + 4 5 + 6 7 + 1 0 + 3 2 + 5 4 + 7 6 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=1, shift=3) + elif n_vector_size * dtype_nbytes == 16: + """ + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + """ + swizzle = None + else: + assert False - return SharedLayout.create(shape=shape, size=self.size, f_offset=f_offset) + return shared_layout( + shape=shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + swizzle=swizzle, + ) def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str: @@ -186,7 +381,9 @@ def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str if len(layout.shape) != 2: raise LayoutOperationError(f"Shared layout with shape {layout.shape} is not supported for visualization.") grid = meshgrid(layout.shape) - offset_grid = vectorized_evaluate(layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)}) + axes = index_vars(num_vars=len(layout.shape)) + offset = layout(*axes) + offset_grid = vectorized_evaluate(offset, var2value={axis: grid[i] for i, axis in enumerate(axes)}) table = [] for i in range(layout.shape[0]): row = [] diff --git a/python/tilus/ir/layout/ops/utils.py b/python/tilus/ir/layout/ops/utils.py index 4e0699f8..bd34fb43 100644 --- a/python/tilus/ir/layout/ops/utils.py +++ b/python/tilus/ir/layout/ops/utils.py @@ -25,6 +25,12 @@ def get_mode_groups(shape: Sequence[int], mode_shape: Sequence[int]) -> list[lis """ Get the groups of modes based on the shape and mode_shape. 
+ Example: + >>> shape = [64, 32] + >>> mode_shape = [8, 8, 16, 2] + >>> get_mode_groups(shape, mode_shape) + [[0, 1], [2, 3]] + Parameters ---------- shape: Sequence[int] diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 352779e7..32ca36c6 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -15,14 +15,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, Dict, List, Sequence, Optional +from typing import Optional, Sequence -from hidet.ir.expr import Expr, Var, as_expr +from hidet.ir.expr import Expr, as_expr from hidet.ir.utils.index_transform import index_deserialize +from hidet.utils import prod -from tilus.extensions.hidet.ir.expr import index_vars from tilus.ir.node import IRNode -from tilus.ir.layout.ops.utils import get_mode_groups + @dataclass(frozen=True, eq=True) class Swizzle: @@ -34,6 +34,7 @@ class Swizzle: y_mask = ((1 << self.bbits) - 1) << (self.mbase + self.sshift) return offset ^ ((offset & y_mask) >> self.sshift) """ + base: int bits: int shift: int @@ -46,19 +47,22 @@ def __call__(self, index: Expr) -> Expr: if self.bits == 0: return index return swizzle(index, self.base, self.bits, self.shift) - + + def __str__(self): + return f"Swizzle(base={self.base}, bits={self.bits}, shift={self.shift})" + @dataclass(frozen=True, eq=False) class SharedLayout(IRNode): """The layout for shared tensor. We use three components to describe a shared tensor layout: the shape, the mode shape, and the mode strides. - - The mode shape and mode strides are used to describe how to split each dimension into multiple sub-dimensions (modes), - and the strides of each mode. + + The mode shape and mode strides are used to describe how to split each dimension into multiple sub-dimensions (modes), + and the strides of each mode. For example, consider a shape of (64, 32), we can split the first dimension into two sub-dimensions (modes) of size 8 and 8, - and the second dimension into two sub-dimensions (modes) of size 16 and 2. The mode shape would be (8, 8, 16, 2). We can + and the second dimension into two sub-dimensions (modes) of size 16 and 2. The mode shape would be (8, 8, 16, 2). We can have strides for each mode, for example, (256, 2, 16, 1). Then given the indices (i, j), we can compute the indices in the sub-dimensions (i1, i2, j1, j2) where i1 = i // 8, i2 = i % 8, j1 = j // 2, j2 = j % 2. The offset can be computed as: offset = i1 * 256 + i2 * 2 + j1 * 16 + j2 * 1. To get the final offset in the shared tensor, we can use the formula: @@ -69,7 +73,7 @@ class SharedLayout(IRNode): shape: tuple[int, ...] The shape of the shared tensor. Each dimension is a constant integer. mode_shape: tuple[int, ...] - We can split each dimension into multiple sub-dimensions (modes). + We can split each dimension into multiple sub-dimensions (modes). mode_strides: tuple[int, ...] The strides of each mode. swizzle: Optional[Swizzle] @@ -96,24 +100,28 @@ def __call__(self, *indices: Expr) -> Expr: ret: Expr The computed offset of the shared tensor element at the given indices. 
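Concretely, the swizzle applied here is a pure bit manipulation on the flattened offset. A plain-integer sketch for base=3, bits=3, shift=3, the parameters shared_row_major_swizzle picks for fp16 when the contiguous row vector is 128 bytes:

def apply_swizzle(offset: int, base: int, bits: int, shift: int) -> int:
    # XOR the `bits` bits starting at `base` with the `bits` bits sitting `shift` positions higher
    y_mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & y_mask) >> shift)

# element (i, j) = (5, 20) in a row-major 8 x 64 fp16 tile
offset = 5 * 64 + 20
assert apply_swizzle(offset, base=3, bits=3, shift=3) == 5 * 64 + (20 ^ (5 * 8))
# i.e. the 16-byte group index 20 // 8 is XOR-ed with the row index 5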
""" + from tilus.ir.layout.ops.utils import get_mode_groups + # get the stride-based index group_modes = get_mode_groups(self.shape, self.mode_shape) mode_indices: list[Expr] = [] for index, modes in zip(indices, group_modes): mode_indices.extend(index_deserialize(index, shape=[self.mode_shape[m] for m in modes])) total_index: Expr = as_expr(sum(index * stride for index, stride in zip(mode_indices, self.mode_strides))) - + # apply swizzle if exists if self.swizzle is not None: total_index = self.swizzle(total_index) - + return total_index @staticmethod - def create(shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle]) -> SharedLayout: + def create( + shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle] + ) -> SharedLayout: """ Create a SharedLayout from shape, mode_shape, and mode_strides. - + Parameters ---------- shape: Sequence[int] @@ -134,11 +142,14 @@ def create(shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequen raise ValueError("All dimensions in shape must be positive integers.") if len(mode_shape) != len(mode_strides): raise ValueError("mode_shape and mode_strides must have the same length.") - return SharedLayout(shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle) + if prod(mode_shape) != prod(shape): + raise ValueError("The product of mode_shape must equal to the product of shape.") + return SharedLayout( + shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle + ) - @property - def size(self) -> int: - """Get the total size of the shared layout. + def count_size(self) -> int: + """Count the total size of the shared layout. It is the minimum number of elements required to store the tensor in shared memory. 
@@ -157,23 +168,15 @@ def slice(self, offsets: Sequence[Expr], slice_dims: Sequence[int], slice_shape: def simplify(self) -> SharedLayout: raise RuntimeError("No need to simplify anymore.") - def with_swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: - raise RuntimeError("Update swizzle.") - # ndims = len(self.shape) - # assert 0 <= dim < ndims and 0 <= regards_dim < ndims and dim != regards_dim - # def get_xor_index(indices: Sequence[Expr]) -> Expr: - # indices = list(indices) # copy - # step = 2**log_step - # regards_index = indices[regards_dim] // step - # regards_extent = self.shape[regards_dim] // step - # if regards_extent > self.shape[dim]: - # regards_index = regards_index % self.shape[dim] - # return regards_index - # def f_offset(axes: Sequence[Var]) -> Expr: - # swizzled_indices: List[Expr] = [axis for axis in axes] - # swizzled_indices[dim] = swizzled_indices[dim] ^ get_xor_index(axes) - # return self(*swizzled_indices) - # return SharedLayout.create(shape=self.shape, size=self.size, f_offset=f_offset) + def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout: + if self.swizzle is not None: + raise RuntimeError("Chained swizzle is not supported.") + return SharedLayout.create( + shape=self.shape, + mode_shape=self.mode_shape, + mode_strides=self.mode_strides, + swizzle=swizzle, + ) def prepend_dim(self, extent: int) -> SharedLayout: shape = (extent,) + self.shape @@ -205,8 +208,49 @@ def unsqueeze(self, dims: Sequence[int]) -> SharedLayout: return shared_unsqueeze(self, dims) - def visualize(self, tablefmt: str = "simple_grid") -> str: from tilus.ir.layout.ops.shared_ops import visualize_layout return visualize_layout(self, tablefmt=tablefmt) + + +def shared_layout( + shape: Sequence[int], + mode_shape: Sequence[int], + mode_strides: Sequence[int], + swizzle: Optional[Swizzle] = None, +) -> SharedLayout: + """Create a SharedLayout from shape, mode_shape, and mode_strides. + + Parameters + ---------- + shape: Sequence[int] + The shape of the shared tensor. + mode_shape: Sequence[int] + The mode shape of the shared tensor. + mode_strides: Sequence[int] + The mode strides of the shared tensor. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. + + Returns + ------- + ret: SharedLayout + The created SharedLayout. 
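The canonicalization performed below is easy to see on a small example (a standalone sketch; the strides paired with extent-1 modes carry no information and are dropped together with them):

mode_shape = (1, 8, 8, 1, 16, 2)
mode_strides = (0, 256, 2, 0, 16, 1)
kept = [(extent, stride) for extent, stride in zip(mode_shape, mode_strides) if extent > 1]
assert [extent for extent, _ in kept] == [8, 8, 16, 2]
assert [stride for _, stride in kept] == [256, 2, 16, 1]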
+ """ + # canonicalize mode shape: clean up mode_shape and mode_strides by removing size 1 modes + if any(s <= 1 for s in mode_shape): + updated_mode_shape = [] + updated_mode_strides = [] + for ms, stride in zip(mode_shape, mode_strides): + if ms > 1: + updated_mode_shape.append(ms) + updated_mode_strides.append(stride) + mode_shape = updated_mode_shape + mode_strides = updated_mode_strides + + # canonicalize swizzle: if swizzle has 0 bits, set it to None (both mean no swizzle) + if swizzle is not None and swizzle.bits == 0: + swizzle = None + + return SharedLayout.create(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=swizzle) diff --git a/python/tilus/ir/layout/utils/cute.py b/python/tilus/ir/layout/utils/cute.py index 6093c6b1..8843b118 100644 --- a/python/tilus/ir/layout/utils/cute.py +++ b/python/tilus/ir/layout/utils/cute.py @@ -21,6 +21,7 @@ from hidet.utils import prod from tilus.extensions.hidet.ir.primitives.swizzle import swizzle from tilus.extensions.hidet.ir.utils.index_transform import index_deserialize +from tilus.ir.layout.shared_layout import SharedLayout, Swizzle, shared_layout Int = Union[Expr, int] IntTuple = Int | Sequence[Union[Int, "IntTuple"]] @@ -79,8 +80,7 @@ def __str__(self) -> str: return f"{self.shape}:{self.strides}" def __call__(self, *coords: IntTuple) -> Int: - coords = specialize(coords, self.shape) - ret = tuple_sum(tuple_multiply(coords, self.strides)) + ret = tuple_sum(tuple_multiply(specialize(coords, self.shape), self.strides)) return ret @property @@ -110,6 +110,9 @@ def __call__(self, offset: Int) -> Int: # return offset ^ ((offset & y_mask) >> self.sshift) return swizzle(int32(offset), self.mbase, self.bbits, self.sshift) + def as_swizzle(self) -> Swizzle: + return Swizzle(base=self.mbase, bits=self.bbits, shift=self.sshift) + class SwizzledCuteLayout: def __init__(self, layout: CuteLayout, swizzle: CuteSwizzle): @@ -122,6 +125,37 @@ def __str__(self) -> str: def __call__(self, *coords: IntTuple) -> Int: return self.swizzle(self.layout(*coords)) + def as_shared_layout(self, tensor_shape: Sequence[int]) -> SharedLayout: + # since cute layout use column-major order when splitting modes, we need to reverse the shape and strides + def reverse_int_tuple(t: IntTuple) -> IntTuple: + if isinstance(t, Sequence): + return tuple(reverse_int_tuple(item) for item in reversed(t)) + else: + return t + + rev_shape = reverse_int_tuple(self.layout.shape) + rev_strides = reverse_int_tuple(self.layout.strides) + + # then, we flatten them into 1D lists + def flatten_int_tuple(t: IntTuple) -> list[Int]: + if isinstance(t, Sequence): + result = [] + for item in t: + result.extend(flatten_int_tuple(item)) + return result + else: + return [t] + + flat_shape = flatten_int_tuple(rev_shape) + flat_strides = flatten_int_tuple(rev_strides) + + mode_shape = [int(s) for s in flat_shape] + mode_strides = [int(s) for s in flat_strides] + + return shared_layout( + shape=tensor_shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=self.swizzle.as_swizzle() + ) + def cute_layout(shape: IntTuple, strides: IntTuple) -> CuteLayout: return CuteLayout(shape, strides) diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index fe51c446..52a06335 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -448,12 +448,11 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> Doc: return self.add_key_comment("layout", doc) def visit_SharedLayout(self, node: SharedLayout) -> Doc: - for i, axis in 
enumerate(node.axes): - self.set_var_name(axis, "u" + str(i)) items = [ "shape=[" + self(node.shape) + "]", - "axes=[" + self(node.axes) + "]", - "offset=" + self(node.offset), + "mode_shape=[" + self(node.mode_shape) + "]", + "mode_strides=[" + self(node.mode_strides) + "]", + "swizzle=" + (str(node.swizzle) if node.swizzle is not None else "None"), ] doc = Text("SharedLayout(") + doc_join(items, ", ") + ")" return self.add_key_comment("shared_layout", doc) diff --git a/python/tilus/lang/modules/cuda.py b/python/tilus/lang/modules/cuda.py index f95e2215..64b47949 100644 --- a/python/tilus/lang/modules/cuda.py +++ b/python/tilus/lang/modules/cuda.py @@ -18,14 +18,13 @@ import cuda.bindings.runtime as cudart from hidet.ir.dtypes import DataType, bfloat16, float16, float32, int8, int32 -from hidet.ir.expr import as_expr from tilus import RegisterLayout from tilus.backends.emitters.cuda.mma_dot import AtomicMmaConfig from tilus.ir.layout import SharedLayout -from tilus.ir.layout.ops import auto_local_spatial, reduce, shared_compose, shared_row_major, spatial +from tilus.ir.layout.ops import auto_local_spatial, reduce, shared_row_major, spatial from tilus.ir.utils import vector -from tilus.utils import gcd, idiv, prod +from tilus.utils import gcd, prod @dataclass(frozen=True, eq=False) @@ -206,186 +205,6 @@ def shared_layout(shape: Sequence[int]) -> SharedLayout: """ return shared_row_major(*shape) - @staticmethod - def swizzled_shared_layout(dtype: DataType, *, shape: Sequence[int]) -> SharedLayout: - """ - Generate a shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. - - Both m and n must be a multiple of 8. - - We will divide each row into bank groups, and bank group has 16 bytes (16 x uint8, 8 x fp16, or 4 x fp32, etc.). - They correspond to 4 banks in shared memory. For example, if we have m = n = 8, we can represent bank groups as - - 0 # bank group 0, banks from 0 to 3 - 1 # bank group 1, banks from 4 to 7 - 2 # ... - 3 - 4 - 5 - 6 - 7 # bank groups 7, banks from 28 to 31 - - Given m, and n, we need to find a proper way to organize the m x (n / 8) bank groups in shared memory, so that - 1) each row has different bank groups - 2) each column has different bank groups - - When we have m = 8 and n = 64, we have 8 x 8 bank groups. If we store the elements in row-major order, we will - have the bank groups as - - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - - If we use ldmatrix to load the above 8 x 64 shared memory, we will need 8 ldmatrix.v1 instructions. Each instruction - loads one column (8 x 8 elements, or 8 x 1 bank groups). Since each instruction will access the same bank group, - severe bank conflicts will occur. Thus, we need to change the layout of shared memory to avoid bank conflicts. - - Let layout(i, j) be the shared memory address of logical elements (each element has 16 bytes) when we use - a specific `layout`. For example, the row-major layout row-major(i, j) = i * n + j * 8 (we assume the dtype has 2 - bytes). If we use the swizzled layout swizzled(i, j) = row-major(i, j ^ i) = i * n + (j ^ i) * 8, we can have the - following bank groups in shared memory. - - 0 1 2 3 4 5 6 7 - 1 0 3 2 5 4 7 6 - 2 3 0 1 6 7 4 5 - 3 2 1 0 7 6 5 4 - 4 5 6 7 0 1 2 3 - 5 4 7 6 1 0 3 2 - 6 7 4 5 2 3 0 1 - 7 6 5 4 3 2 1 0 - - (reader may need some time to figure out the above layout...) 
- - This layout has two benefits: - 1) Each row has different bank groups. In above example, we have 32 banks per row. - 2) Each column has different bank groups. In above example, we have 32 banks per column. - - The benefit 1 makes sure that when we load data from global memory to shared memory, we can store efficiently. - The benefit 2 makes sure that when we load data from shared memory to register memory, we can load efficiently. - - We can always generate the swizzled layout for arbitrary m and n as long as they are multiple of 8. See the - implementation for more details. - - Parameters - ---------- - dtype: DataType - The element data type for both the shared memory and the register memory. - - shape: Sequence[int] - The shape of the shared memory. The shape must have at least two dimensions. - - Returns - ------- - shared_layout: SharedLayout - The shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. - """ - return cuda._swizzled_shared_layout_new(dtype, shape=tuple(shape)) - - @staticmethod - @functools.lru_cache - def _swizzled_shared_layout(dtype: DataType, shape: tuple[int, ...]) -> SharedLayout: - if len(shape) < 2: - raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") - m, n = shape[-2:] - group_elements = idiv(16, dtype.nbytes) - if m % 8 != 0 or n % group_elements != 0: - raise ValueError("m must be a multiple of 8, and n must be a multiple of dtype.nbytes * 8.") - rows = m - columns = n // group_elements - - if columns % 8 == 0: - """ - 0 1 2 3 4 5 6 7 - 1 0 3 2 5 4 7 6 - 2 3 0 1 6 7 4 5 - 3 2 1 0 7 6 5 4 - 4 5 6 7 0 1 2 3 - 5 4 7 6 1 0 3 2 - 6 7 4 5 2 3 0 1 - 7 6 5 4 3 2 1 0 - """ - core = shared_row_major(rows, columns).with_swizzle(dim=1, regards_dim=0, log_step=0) - elif columns % 4 == 0: - """ - 0 1 2 3 - 4 5 6 7 - 1 0 3 2 - 5 4 7 6 - 2 3 0 1 - 6 7 4 5 - 3 2 1 0 - 7 6 5 4 - """ - core = shared_row_major(rows, 4).with_swizzle(dim=1, regards_dim=0, log_step=1) - elif columns % 2 == 0: - """ - 0 1 - 2 3 - 4 5 - 6 7 - 1 0 - 3 2 - 5 4 - 7 6 - """ - core = shared_row_major(rows, 2).with_swizzle(dim=1, regards_dim=0, log_step=2) - else: - """ - 0 - 1 - 2 - 3 - 4 - 5 - 6 - 7 - """ - core = shared_row_major(rows, 1) - layout = shared_compose(core, shared_row_major(1, group_elements)) - if m > layout.shape[0] or n > layout.shape[1]: - layout = shared_compose(shared_row_major(m // layout.shape[0], n // layout.shape[1]), layout) - if len(shape) > 2: - for extent in reversed(shape[:-2]): - layout = layout.prepend_dim(extent=extent) - return layout - - @staticmethod - @functools.lru_cache - def _swizzled_shared_layout_new(dtype: DataType, shape: tuple[int, ...]) -> SharedLayout: - if len(shape) < 2: - raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") - m, n = shape[-2:] - group_elements = idiv(16, dtype.nbytes) - if m % 8 != 0 or n % group_elements != 0: - raise ValueError("m must be a multiple of 8, and n must be a multiple of dtype.nbytes * 8.") - - def f_offset(axes): - strides: list[int] = [prod(shape[i + 1 :]) for i in range(len(shape))] - - columns: int = n // group_elements - columns_vec_size: int = gcd(columns, 8) - - i, j = axes[-2:] - if columns_vec_size == 8: - i, j = i, j ^ ((i % 8) * group_elements) - elif columns_vec_size == 4: - i, j = i, j ^ (i // 2 % 4 * group_elements) - elif columns_vec_size == 2: - i, j = i, j ^ (i // 4 % 2 * group_elements) - else: - i, j = i, j - swizzled_axes = axes[:-2] + [i, j] - offset = as_expr(sum(axis * stride for 
axis, stride in zip(swizzled_axes, strides))) - return offset - - layout = SharedLayout.create(shape=shape, size=prod(shape), f_offset=f_offset).simplify() - return layout - @staticmethod def default_register_layout( num_warps: int, dtype: DataType, shape: Sequence[int], vector_size: Optional[int] = None From 46a389b34d2dc198d8ea572e383678218f5d5ecb Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 5 Dec 2025 08:31:27 +0000 Subject: [PATCH 04/17] wip Signed-off-by: Yaoyao Ding --- pyproject.toml | 1 + python/tilus/py.typed | 0 2 files changed, 1 insertion(+) create mode 100644 python/tilus/py.typed diff --git a/pyproject.toml b/pyproject.toml index 36376f74..158428c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ Documentation = "https://nvidia.github.io/tilus" "" = "python" [tool.setuptools.package-data] +"tilus" = ["py.typed"] "tilus.extensions.hidet" = ["include/**/*.h"] [tool.ruff] diff --git a/python/tilus/py.typed b/python/tilus/py.typed new file mode 100644 index 00000000..e69de29b From fda540568a24d141bab0c1385a8866454969e85e Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 5 Dec 2025 08:46:22 +0000 Subject: [PATCH 05/17] wip Signed-off-by: Yaoyao Ding --- .pre-commit-config.yaml | 23 ++++++++----------- pyproject.toml | 2 ++ .../inference/inference_rules/load_shared.py | 1 + 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index abe35007..ce387d71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,15 @@ repos: language: system types: [python] pass_filenames: false + - id: mypy + name: MyPy type checking + entry: mypy + language: system + types_or: [python, pyi] + args: [--show-error-codes, --show-error-context] + files: ^(python|examples|tests)/ + exclude: ^python/tilus/extensions/hidet/ + require_serial: true - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version rev: v0.12.3 @@ -30,17 +39,3 @@ repos: - id: ruff-format name: Ruff formatter types_or: [python, pyi] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 - hooks: - - id: mypy - name: MyPy type checking - # Uses [tool.mypy] configuration from pyproject.toml - additional_dependencies: [ - types-tabulate, - types-tqdm, - ] - # Check files in the python package, examples, and tests folders - files: ^(python|examples|tests)/ - # Exclude hidet extensions (they have special mypy overrides in pyproject.toml) - exclude: ^python/tilus/extensions/hidet/ diff --git a/pyproject.toml b/pyproject.toml index 158428c9..07ec68f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,8 @@ disallow_incomplete_defs = true allow_redefinition = true strict_optional = false explicit_package_bases = true +mypy_path = "python" +namespace_packages = true [[tool.mypy.overrides]] module = ["tilus.*"] diff --git a/python/tilus/ir/layout/inference/inference_rules/load_shared.py b/python/tilus/ir/layout/inference/inference_rules/load_shared.py index bb57444a..a7da96d8 100644 --- a/python/tilus/ir/layout/inference/inference_rules/load_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/load_shared.py @@ -85,6 +85,7 @@ def inference( if not (shared.has_layout() and not register.has_layout()): return {} + axes = shared.layout.axes axes = shared.layout.axes offset = shared.layout.offset From 2abd2095bae422a228b651d6defabd3a669fd46d Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 5 Dec 2025 22:44:47 +0000 Subject: [PATCH 06/17] format & linting Signed-off-by: Yaoyao Ding --- 
examples/attention/flash_attention_v3.py | 2 +- .../flash_attention_decode/tilus_kernel.py | 4 +- pyproject.toml | 13 +++-- python/tilus/backends/codegen.py | 8 +-- .../tilus/backends/emitters/cuda/cp_async.py | 3 +- .../backends/emitters/cuda/cp_async_tensor.py | 5 +- python/tilus/backends/emitters/reduce.py | 4 +- python/tilus/ir/builders/stmt_builder.py | 27 +++++----- python/tilus/ir/functors/functor.py | 8 +-- python/tilus/ir/layout/cuda/tcgen05/smem.py | 12 +---- .../inference/inference_rules/load_shared.py | 4 +- python/tilus/ir/layout/ops/shared_ops.py | 9 +--- python/tilus/ir/layout/shared_layout.py | 19 ++++++- python/tilus/ir/tensor.py | 49 +++++++++++++++---- python/tilus/ir/tools/printer.py | 2 +- python/tilus/lang/instructions/root.py | 6 +-- python/tilus/lang/instructions/tcgen05.py | 1 + python/tilus/lang/script.py | 2 +- python/tilus/lang/transpiler/transpiler.py | 4 +- .../tools/verifier/test_verify_load_shared.py | 2 +- 20 files changed, 111 insertions(+), 73 deletions(-) diff --git a/examples/attention/flash_attention_v3.py b/examples/attention/flash_attention_v3.py index 9a44e758..d79ddf87 100644 --- a/examples/attention/flash_attention_v3.py +++ b/examples/attention/flash_attention_v3.py @@ -293,7 +293,7 @@ def store_back( shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], requires_clean=False, ) - semaphore = ~semaphores[self.blockIdx.x, bs, head] + semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr() sm = self.shared_tensor(dtype=f32, shape=[self.block_q]) sl = self.shared_tensor(dtype=f32, shape=[self.block_q]) diff --git a/examples/flash_attention_decode/tilus_kernel.py b/examples/flash_attention_decode/tilus_kernel.py index c0885326..bf670c35 100644 --- a/examples/flash_attention_decode/tilus_kernel.py +++ b/examples/flash_attention_decode/tilus_kernel.py @@ -152,7 +152,7 @@ def __call__( dims=[2, 3], ) else: - state_idx = -1 + state_idx = -1 # type: ignore r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0) # H' = alpha * H : [K, BV] = [] * [K, BV] @@ -388,7 +388,7 @@ def __call__( dims=[2, 3], ) else: - state_idx = -1 + state_idx = -1 # type: ignore r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0) # Apply gating to hidden state: H' = alpha * H diff --git a/pyproject.toml b/pyproject.toml index 07ec68f0..ce8125d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,9 @@ strict_optional = false explicit_package_bases = true mypy_path = "python" namespace_packages = true +disable_error_code = [ + "import-untyped", # we used some untyped third-party libraries +] [[tool.mypy.overrides]] module = ["tilus.*"] @@ -123,9 +126,13 @@ disable_error_code = [ "override", ] -[[tool.mypy.overrides]] -module = ["examples.*", "tests.*"] -disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef"] +[[tool.mypy.overrides]] # disable type checking in these modules that might define tilus scripts +module = [ + "examples.*", + "tests.*", + "tilus.lang.classes.pipeline" +] +disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef", "assignment", "import-not-found"] [[tool.mypy.overrides]] module = ["hidet.*"] diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py index aa66394c..129dc6a9 100644 --- a/python/tilus/backends/codegen.py +++ b/python/tilus/backends/codegen.py @@ -143,15 +143,15 @@ def check_emitter_existence(self) -> None: if failed_instructions: rows = [f"Failed to find emitter for the 
following instructions (target: {get_current_target()}):"] required_targets: list[str] = [] - for inst in failed_instructions: + for inst_cls in failed_instructions: for registry_inst_cls, emitter_classes in BaseInstEmitter.REGISTRY.items(): - if issubclass(inst, registry_inst_cls): + if issubclass(inst_cls, registry_inst_cls): required_targets.extend([str(target) for target in emitter_classes.keys()]) break if not required_targets: - rows.append(f" - {inst.__name__} (no registered emitters)") + rows.append(f" - {inst_cls.__name__} (no registered emitters)") else: - rows.append(f" - {inst.__name__} (registered targets: {', '.join(required_targets)})") + rows.append(f" - {inst_cls.__name__} (registered targets: {', '.join(required_targets)})") raise CodeGenerationFailed("\n".join(rows)) def launch_kernel(self, kernel_func: HidetFunction) -> None: diff --git a/python/tilus/backends/emitters/cuda/cp_async.py b/python/tilus/backends/emitters/cuda/cp_async.py index 3d1976af..2ae2955e 100644 --- a/python/tilus/backends/emitters/cuda/cp_async.py +++ b/python/tilus/backends/emitters/cuda/cp_async.py @@ -47,7 +47,8 @@ def emit(self, inst: CopyAsyncGenericInst) -> None: # get shared, global, and mask info inst_mask = inst.mask if inst.mask is not None else boolean.true - shared_info: TensorInfo = analyze_grid(shape=shape, axes=layout.axes, analysis=analysis, expr=layout.offset) + axes, offset = layout.as_axes_mapping() + shared_info: TensorInfo = analyze_grid(shape=shape, axes=axes, analysis=analysis, expr=offset) mask_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst_mask) global_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst.offset) diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index 98629b66..2f7e4006 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -202,12 +202,9 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso range_indices: list[np.ndarray] = [] for dim, extent in enumerate(shared_tensor.shape): range_indices.append(np.arange(extent, dtype=np.int32)) - grid = np.meshgrid(*range_indices, indexing="ij") layout: SharedLayout = shared_tensor.layout - offset_grid: np.ndarray = vectorized_evaluate( - expr=layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)} - ) + offset_grid: np.ndarray = layout.as_numpy_grid() for swizzle in [ TensorMapSwizzle.NONE, TensorMapSwizzle.B32, diff --git a/python/tilus/backends/emitters/reduce.py b/python/tilus/backends/emitters/reduce.py index 32ac8dfc..a2aba1f3 100644 --- a/python/tilus/backends/emitters/reduce.py +++ b/python/tilus/backends/emitters/reduce.py @@ -284,8 +284,8 @@ def inter_warp_reduce(self, inst: ReduceInst) -> None: smem_ctx = self.contexts.smem_alloc_ctx smem_buf = self.declare_var( "smem_buf", - tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.size]), - init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.size), ~dst.dtype), + tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.count_size()]), + init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.count_size()), ~dst.dtype), ) reduced_mode_shape = [ diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 6b362f17..d50f37c1 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ 
b/python/tilus/ir/builders/stmt_builder.py @@ -165,21 +165,15 @@ def __init__( iter_vars: List[Var], extents: List[Expr], unrolls: List[Optional[int]], - unwrap: bool = False, ): super().__init__(vb) self.iter_vars: List[Var] = iter_vars self.extents: List[Expr] = extents self.unrolls: List[Optional[int]] = unrolls - self.unwrap: bool = unwrap - def __enter__(self): + def __enter__(self) -> List[Var]: self.enter() - if self.unwrap: - assert len(self.iter_vars) == 1 - return self.iter_vars[0] - else: - return self.iter_vars + return self.iter_vars def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: @@ -193,6 +187,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.append(body) +class ForRangeContext(ForContext): + def __enter__(self): + iter_vars = super().__enter__() + assert len(iter_vars) == 1 + return iter_vars[0] + + class IfContext(StmtContext): def __init__(self, vb: StmtBuilderCore, cond: Expr): super().__init__(vb) @@ -315,11 +316,13 @@ def is_empty(self): def for_range( self, extent: Union[Expr, int], iter_name_hint: str = "i", unroll_factor: Optional[int] = None - ) -> ForContext: + ) -> ForRangeContext: iter_var = Var(iter_name_hint, type=int32) - return ForContext(self, [iter_var], [as_expr(extent)], [unroll_factor], unwrap=True) + return ForRangeContext(self, [iter_var], [as_expr(extent)], [unroll_factor]) - def for_grid(self, extents: List[Union[Expr, int]], iter_name_hints: Optional[List[str]] = None) -> ForContext: + def for_grid( + self, extents: Sequence[Union[Expr, int]], iter_name_hints: Optional[Sequence[str]] = None + ) -> ForContext: expr_extents = [as_expr(extent) for extent in extents] if iter_name_hints is None: names = "ijkpqrstuvw" @@ -1231,6 +1234,7 @@ def wait_barrier(self, barrier: Expr | RegisterTensor, phase: Expr | int | Regis phase = self.tensor_item_value(phase) elif isinstance(phase, int): phase = uint32(phase) + assert isinstance(phase, Expr) inst = WaitBarrierInst.create(barrier=barrier, phase=phase) self.append(inst) @@ -1245,6 +1249,7 @@ def cluster_launch_control_try_cancel( mbarrier = self.tensor_item_value(mbarrier) if isinstance(multicast, bool): multicast = boolean(multicast) + assert isinstance(multicast, Expr) inst = ClusterLaunchControlTryCancelInst.create(response=response, mbarrier=mbarrier, multicast=multicast) self.append(inst) diff --git a/python/tilus/ir/functors/functor.py b/python/tilus/ir/functors/functor.py index dcc3ec16..aa77f669 100644 --- a/python/tilus/ir/functors/functor.py +++ b/python/tilus/ir/functors/functor.py @@ -435,11 +435,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> RegisterLayout: return layout def visit_SharedLayout(self, layout: SharedLayout) -> SharedLayout: - offset = self.visit(layout.offset) - if offset is layout.offset: - return layout - else: - return SharedLayout(shape=layout.shape, size=layout.size, axes=layout.axes, offset=offset) + return layout def visit_GlobalLayout(self, layout: GlobalLayout) -> GlobalLayout: shape = self.visit(layout.shape) @@ -573,7 +569,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> None: pass def visit_SharedLayout(self, layout: SharedLayout) -> None: - self.visit(layout.offset) + pass def visit_GlobalLayout(self, layout: GlobalLayout) -> None: self.visit(layout.shape) diff --git a/python/tilus/ir/layout/cuda/tcgen05/smem.py b/python/tilus/ir/layout/cuda/tcgen05/smem.py index d903095d..6c8a5000 100644 --- a/python/tilus/ir/layout/cuda/tcgen05/smem.py +++ b/python/tilus/ir/layout/cuda/tcgen05/smem.py @@ -22,7 +22,6 @@ from 
tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode from tilus.ir.layout.shared_layout import SharedLayout from tilus.ir.layout.utils.cute import CuteLayout, CuteSwizzle, IntTuple, SwizzledCuteLayout, cute_layout, tuple_product -from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate from tilus.utils import floor_log2 # class Tcgen05SwizzleMode(Enum): @@ -161,11 +160,7 @@ def _generate_atom_grid(major_kind: Literal["MN", "K"], swizzle_mode: Tcgen05Swi major_kind=major_kind, swizzle_mode=swizzle_mode, SBO=0, LBO=0, m=1, k=1, T=t ) atom_layout = get_shared_layout_from_canonical(canonical_layout) - grid_axes = meshgrid(atom_layout.shape) - atom_grid = vectorized_evaluate( - expr=atom_layout.offset, var2value={axis: grid_axes[i] for i, axis in enumerate(atom_layout.axes)} - ) - return atom_grid + return atom_layout.as_numpy_grid() def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) -> Optional[CanonicalSharedLayout]: @@ -195,10 +190,7 @@ def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) -> T = 128 // dtype.nbits # Create meshgrid for the entire layout - grid_axes = meshgrid(shared_layout.shape) - entire_grid = vectorized_evaluate( - expr=shared_layout.offset, var2value={axis: grid_axes[i] for i, axis in enumerate(shared_layout.axes)} - ) + entire_grid = shared_layout.as_numpy_grid() entire_shape = shared_layout.shape # Try each swizzle mode and majorness using direct pattern analysis diff --git a/python/tilus/ir/layout/inference/inference_rules/load_shared.py b/python/tilus/ir/layout/inference/inference_rules/load_shared.py index a7da96d8..ce09e96d 100644 --- a/python/tilus/ir/layout/inference/inference_rules/load_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/load_shared.py @@ -85,9 +85,7 @@ def inference( if not (shared.has_layout() and not register.has_layout()): return {} - axes = shared.layout.axes - axes = shared.layout.axes - offset = shared.layout.offset + axes, offset = shared.layout.as_axes_mapping() info = analyze_grid( shape=shared.shape, diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index 87eaba87..968bc558 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -19,10 +19,8 @@ import tabulate from hidet.utils import gcd, prod -from tilus.extensions.hidet.ir.expr import index_vars from tilus.ir.layout.ops.utils import LayoutOperationError, get_mode_groups from tilus.ir.layout.shared_layout import SharedLayout, Swizzle, shared_layout -from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate def strides_from_ranks(shape: Sequence[int], ranks: Sequence[int]) -> list[int]: @@ -380,14 +378,11 @@ def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str if len(layout.shape) != 2: raise LayoutOperationError(f"Shared layout with shape {layout.shape} is not supported for visualization.") - grid = meshgrid(layout.shape) - axes = index_vars(num_vars=len(layout.shape)) - offset = layout(*axes) - offset_grid = vectorized_evaluate(offset, var2value={axis: grid[i] for i, axis in enumerate(axes)}) + grid = layout.as_numpy_grid() table = [] for i in range(layout.shape[0]): row = [] for j in range(layout.shape[1]): - row.append(f"{offset_grid[i, j]}") + row.append(f"{grid[i, j]}") table.append(row) return head + "\n" + tabulate.tabulate(table, tablefmt=tablefmt) diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 
32ca36c6..c6bc1ccd 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -17,11 +17,14 @@ from dataclasses import dataclass from typing import Optional, Sequence -from hidet.ir.expr import Expr, as_expr +import numpy as np +from hidet.ir.expr import Expr, Var, as_expr from hidet.ir.utils.index_transform import index_deserialize from hidet.utils import prod +from tilus.extensions.hidet.ir.expr import index_vars from tilus.ir.node import IRNode +from tilus.ir.utils.veceval import vectorized_evaluate @dataclass(frozen=True, eq=True) @@ -148,6 +151,18 @@ def create( shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle ) + def as_numpy_grid(self) -> np.ndarray: + grid_axes = np.meshgrid(*[np.arange(extent) for extent in self.shape]) + axes = index_vars(num_vars=len(self.shape)) + offset = self(*axes) + atom_grid = vectorized_evaluate(expr=offset, var2value={axis: grid_axes[i] for i, axis in enumerate(axes)}) + return atom_grid + + def as_axes_mapping(self) -> tuple[list[Var], Expr]: + axes = index_vars(num_vars=len(self.shape)) + offset = self(*axes) + return axes, offset + def count_size(self) -> int: """Count the total size of the shared layout. @@ -182,7 +197,7 @@ def prepend_dim(self, extent: int) -> SharedLayout: shape = (extent,) + self.shape if extent > 1: mode_shape = (extent,) + self.mode_shape - mode_strides = (self.size,) + self.mode_strides + mode_strides = (self.count_size(),) + self.mode_strides else: mode_shape = self.mode_shape mode_strides = self.mode_strides diff --git a/python/tilus/ir/tensor.py b/python/tilus/ir/tensor.py index 7361edd3..10d1dc16 100644 --- a/python/tilus/ir/tensor.py +++ b/python/tilus/ir/tensor.py @@ -371,7 +371,23 @@ def __eq__(self, other): """ raise RuntimeError("tensor == tensor could only be used in Tilus Script.") - def __xor__(self, other): + def __ne__(self, value): + """ + Not equal to comparison. + + Parameters + ---------- + value: RegisterTensor | int | float | Expr + The tensor or scalar to compare with this tensor. + + Returns + ------- + ret: RegisterTensor + A new tensor that is the result of the comparison. + """ + raise RuntimeError("tensor != tensor could only be used in Tilus Script.") + + def __xor__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """Bitwise XOR operation. Parameters @@ -419,8 +435,23 @@ def __rsub__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor """ raise RuntimeError("tensor - tensor could only be used in Tilus Script.") + def __rtruediv__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: + """Perform right-side division with another tensor or a scalar. + + Parameters + ---------- + other: RegisterTensor | int | float | Expr + The tensor or scalar to divide this tensor by. + + Returns + ------- + ret: RegisterTensor + A new tensor that is the result of the division. + """ + raise RuntimeError("tensor / tensor could only be used in Tilus Script.") + # i-version of operator - def __iadd__(self, other: RegisterTensor | int | float | Expr) -> None: + def __iadd__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place addition operation. 
Parameters @@ -430,7 +461,7 @@ def __iadd__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor += tensor could only be used in Tilus Script.") - def __isub__(self, other: RegisterTensor | int | float | Expr) -> None: + def __isub__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place subtraction operation. Parameters @@ -440,7 +471,7 @@ def __isub__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor -= tensor could only be used in Tilus Script.") - def __imul__(self, other: RegisterTensor | int | float | Expr) -> None: + def __imul__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place multiplication operation. Parameters @@ -450,7 +481,7 @@ def __imul__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor *= tensor could only be used in Tilus Script.") - def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> None: + def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place division operation. Parameters @@ -460,7 +491,7 @@ def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor /= tensor could only be used in Tilus Script.") - def __imod__(self, other: RegisterTensor | int | float | Expr) -> None: + def __imod__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place modulus operation. Parameters @@ -470,7 +501,7 @@ def __imod__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor %= tensor could only be used in Tilus Script.") - def __ixor__(self, other: RegisterTensor | int | float | Expr) -> None: + def __ixor__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place bitwise XOR operation. Parameters @@ -509,7 +540,7 @@ def transpose(self) -> RegisterTensor: def to(self, dtype: DataType) -> RegisterTensor: raise RuntimeError("tensor.to(...) could only be used in Tilus Script.") - def tolist(self) -> Expr | list: + def tolist(self) -> list: raise RuntimeError("tensor.tolist() could only be used in Tilus Script.") @@ -584,7 +615,7 @@ def size(self) -> int: ret: int The size of the SharedTensor, which is the number of elements it contains. 
""" - return self.layout.size + return self.layout.count_size() @property def nbytes(self) -> int: diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index 52a06335..a37cb634 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -114,7 +114,7 @@ def get_tensor_type(self, tensor: Tensor) -> Doc: doc = Text("shared, ") doc += self.printer(tensor.dtype) + "[" + self.visit(tensor.shape) + "]" if tensor.optional_layout is not None: - doc += ", size={}".format(tensor.layout.size) + doc += ", size={}".format(tensor.size) doc += ", {}".format(self.visit(tensor.layout)) return doc elif isinstance(tensor, GlobalTensor): diff --git a/python/tilus/lang/instructions/root.py b/python/tilus/lang/instructions/root.py index ea9cccda..77898a8c 100644 --- a/python/tilus/lang/instructions/root.py +++ b/python/tilus/lang/instructions/root.py @@ -34,12 +34,12 @@ class RootInstructionGroup(InstructionGroup): @property def blockIdx(self) -> Dim3: """Get the block index of the current thread block.""" - return Dim3(blockIdx.x, blockIdx.y, blockIdx.z) + return Dim3(blockIdx.x, blockIdx.y, blockIdx.z) # type: ignore[attr-defined] @property def gridDim(self) -> Dim3: """Get the grid dimension of the kernel.""" - return Dim3(gridDim.x, gridDim.y, gridDim.z) + return Dim3(gridDim.x, gridDim.y, gridDim.z) # type: ignore[attr-defined] @property def current_thread_begin(self) -> int: @@ -1627,7 +1627,7 @@ def print_tensor(self, msg: str, tensor: Tensor, fmt: Optional[str] = None) -> N """ self._builder.print_tensor(msg=msg, tensor=tensor, fmt=fmt) - def printf(self, fstring: str, *args: Expr | int | float) -> None: + def printf(self, fstring: str, *args: Expr | int | float | str) -> None: """Print a formatted string. This instruction prints a formatted string to the standard output. The `fstring` parameter is a format string diff --git a/python/tilus/lang/instructions/tcgen05.py b/python/tilus/lang/instructions/tcgen05.py index 187aa816..a1aa770d 100644 --- a/python/tilus/lang/instructions/tcgen05.py +++ b/python/tilus/lang/instructions/tcgen05.py @@ -53,6 +53,7 @@ def alloc( "The thread group used to allocate with initialization must start at a multiple of 128 " "and have at least 128 threads." ) + ctx: contextlib.AbstractContextManager if thread_end - thread_begin == 128: ctx = contextlib.nullcontext() else: diff --git a/python/tilus/lang/script.py b/python/tilus/lang/script.py index 63657abf..9f7e4893 100644 --- a/python/tilus/lang/script.py +++ b/python/tilus/lang/script.py @@ -43,7 +43,7 @@ class Script(InstructionInterface): # specify the schedule used for debugging. 
it will override any autotune space debug_schedule: Optional[dict[str, Any]] = None - def __new__(cls, *args, **kwargs) -> InstantiatedScript: # type: ignore[no-untyped-def] + def __new__(cls, *args, **kwargs) -> InstantiatedScript: # type: ignore[no-untyped-def, misc] from tilus.lang.instantiated_script import InstantiatedScriptCache instantiated_script: InstantiatedScript = InstantiatedScriptCache.get( diff --git a/python/tilus/lang/transpiler/transpiler.py b/python/tilus/lang/transpiler/transpiler.py index 3b2aec13..09fcf678 100644 --- a/python/tilus/lang/transpiler/transpiler.py +++ b/python/tilus/lang/transpiler/transpiler.py @@ -203,7 +203,7 @@ def transpile( metadata = Metadata( grid_blocks=normalize_grid_blocks(script.attrs.blocks), cluster_blocks=normalize_cluster_blocks(script.attrs.cluster_blocks), - block_indices=(blockIdx.x, blockIdx.y, blockIdx.z), + block_indices=(blockIdx.x, blockIdx.y, blockIdx.z), # type: ignore[attr-defined] num_warps=script.attrs.warps, param2divisibility=frozendict(param2divisibility), analysis=None, @@ -686,7 +686,7 @@ def visit_Call(self, expr: ast.Call) -> Any: # case 4 class_cls: Type[Class] = func obj = object.__new__(class_cls) - self.transpile_call(obj.__init__, args, kwargs) + self.transpile_call(obj.__init__, args, kwargs) # type: ignore ret = obj elif func is super: # case 5 diff --git a/tests/ir/tools/verifier/test_verify_load_shared.py b/tests/ir/tools/verifier/test_verify_load_shared.py index 693671ae..0a90c567 100644 --- a/tests/ir/tools/verifier/test_verify_load_shared.py +++ b/tests/ir/tools/verifier/test_verify_load_shared.py @@ -41,7 +41,7 @@ def __call__(self): ) def test_verify_load_shared(shared_shape, register_shape, success): script = DemoLoadShared(shared_shape=shared_shape, register_shape=register_shape) - program = script._jit_instance_for().transpiled_programs[0] + program = script._jit_instance_for().transpiled_programs[0] # type: ignore if success: verify(program) From 31f4dbfdd21f0fc243e9896ff995afa19f62e111 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 00:27:51 +0000 Subject: [PATCH 07/17] wip Signed-off-by: Yaoyao Ding --- .../inference_rules/transform_shared.py | 10 ++++-- python/tilus/ir/layout/ops/shared_ops.py | 32 +++++++++++++++++++ python/tilus/ir/layout/shared_layout.py | 6 ++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py index 172d8627..4fd73d57 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py @@ -28,13 +28,19 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared if a.optional_layout is not None and b.optional_layout is not None: return {} elif a.optional_layout is not None: - return {b: a.layout.slice(offsets=inst.offsets, slice_dims=inst.dims, slice_shape=b.shape).simplify()} + if inst.dims is None: + dims = list(range(len(a.shape))) + else: + dims = inst.dims + return { + b: a.layout.slice(retain_dims=dims) + } elif b.optional_layout is not None: b_layout = b.layout.unsqueeze(dims=range(len(a.shape) - len(b.shape))) outer_shape = [] for i in range(len(a.shape)): outer_shape.append(a.shape[i] // b_layout.shape[i]) - return {a: shared_compose(shared_row_major(*outer_shape), b_layout).simplify()} + return {a: shared_compose(shared_row_major(*outer_shape), b_layout)} else: return {} diff --git 
a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py
index 968bc558..260344d6 100644
--- a/python/tilus/ir/layout/ops/shared_ops.py
+++ b/python/tilus/ir/layout/ops/shared_ops.py
@@ -153,6 +153,38 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout:
     return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=layout.swizzle)
 
+def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayout:
+    """Slice the shared layout by keeping only the specified dimensions.
+
+    Parameters
+    ----------
+    layout: SharedLayout
+        The layout to slice.
+    retain_dims: Sequence[int]
+        The dimensions to keep. Each dimension should be in the range [0, len(layout.shape)) and must be unique.
+        Only the listed dimensions are kept in the output layout; all other dimensions are dropped.
+
+    Returns
+    -------
+    ret: SharedLayout
+        The sliced layout.
+    """
+    assert all(0 <= d < len(layout.shape) for d in retain_dims) and len(retain_dims) == len(set(retain_dims))
+    shape: List[int] = []
+    mode_shape: List[int] = []
+    mode_strides: List[int] = []
+    layout_mode_groups = get_mode_groups(layout.shape, layout.mode_shape)
+    for i in retain_dims:
+        shape.append(layout.shape[i])
+        mode_shape.extend([layout.mode_shape[j] for j in layout_mode_groups[i]])
+        mode_strides.extend([layout.mode_strides[j] for j in layout_mode_groups[i]])
+
+    return shared_layout(
+        shape=shape,
+        mode_shape=mode_shape,
+        mode_strides=mode_strides,
+        swizzle=layout.swizzle,
+    )
 def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout:
     """Unsqueeze the shared layout by adding new dimensions of size 1.
diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py
index c6bc1ccd..b5d9ece2 100644
--- a/python/tilus/ir/layout/shared_layout.py
+++ b/python/tilus/ir/layout/shared_layout.py
@@ -177,8 +177,10 @@ def count_size(self) -> int:
         max_index = sum(a * b for a, b in zip(indices, self.mode_strides))
         return max_index + 1
 
-    def slice(self, offsets: Sequence[Expr], slice_dims: Sequence[int], slice_shape: Sequence[int]) -> SharedLayout:
-        raise RuntimeError("No slice anymore.")
+    def slice(self, retain_dims: Sequence[int]) -> SharedLayout:
+        from tilus.ir.layout.ops.shared_ops import shared_slice
+
+        return shared_slice(self, retain_dims)
 
     def simplify(self) -> SharedLayout:
         raise RuntimeError("No need to simplify anymore.")

From 47fb33bedaca8b6f33cce0c8b2b1d7ca5803b691 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Mon, 8 Dec 2025 00:57:20 +0000
Subject: [PATCH 08/17] wip

Signed-off-by: Yaoyao Ding
---
 python/tilus/ir/layout/cuda/tcgen05/smem.py |  1 +
 python/tilus/ir/layout/utils/cute.py        |  7 ++-
 tests/instructions/test_tcgen05_copy.py     |  5 ++-
 tests/ir/layout/test_tcgen05_smem.py        | 49 ++++++++++++++++++++-
 4 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/python/tilus/ir/layout/cuda/tcgen05/smem.py b/python/tilus/ir/layout/cuda/tcgen05/smem.py
index 6c8a5000..c58daf1a 100644
--- a/python/tilus/ir/layout/cuda/tcgen05/smem.py
+++ b/python/tilus/ir/layout/cuda/tcgen05/smem.py
@@ -181,6 +181,7 @@ def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) ->
     ret: Optional[CanonicalSharedLayout]
         The canonical form if found, None otherwise
     """
+    # todo: simplify the implementation of this function since we now use a cute-like layout system
     if len(shared_layout.shape) != 2:
         return None
 
diff --git a/python/tilus/ir/layout/utils/cute.py b/python/tilus/ir/layout/utils/cute.py
index 8843b118..f9e277ed 100644
---
a/python/tilus/ir/layout/utils/cute.py +++ b/python/tilus/ir/layout/utils/cute.py @@ -132,9 +132,12 @@ def reverse_int_tuple(t: IntTuple) -> IntTuple: return tuple(reverse_int_tuple(item) for item in reversed(t)) else: return t + + assert isinstance(self.layout.shape, Sequence) + assert isinstance(self.layout.strides, Sequence) - rev_shape = reverse_int_tuple(self.layout.shape) - rev_strides = reverse_int_tuple(self.layout.strides) + rev_shape = [reverse_int_tuple(item) for item in self.layout.shape] + rev_strides = [reverse_int_tuple(item) for item in self.layout.strides] # then, we flatten them into 1D lists def flatten_int_tuple(t: IntTuple) -> list[Int]: diff --git a/tests/instructions/test_tcgen05_copy.py b/tests/instructions/test_tcgen05_copy.py index 6426fd82..4b1383cc 100644 --- a/tests/instructions/test_tcgen05_copy.py +++ b/tests/instructions/test_tcgen05_copy.py @@ -58,8 +58,9 @@ def __call__(self, m_size: int, n_size: int, x_ptr: ~int32, y_ptr: ~int32): self.sync() # copy x from shared to tmem - self.tcgen05.copy(src=s_x, dst=t_x) - self.tcgen05.commit(mbarrier=barriers[0]) + with self.single_thread(): + self.tcgen05.copy(src=s_x, dst=t_x) + self.tcgen05.commit(mbarrier=barriers[0]) self.mbarrier.wait(barriers[0], phase=0) # load y from tmem to register diff --git a/tests/ir/layout/test_tcgen05_smem.py b/tests/ir/layout/test_tcgen05_smem.py index e3901edc..f1f9f74b 100644 --- a/tests/ir/layout/test_tcgen05_smem.py +++ b/tests/ir/layout/test_tcgen05_smem.py @@ -362,6 +362,8 @@ def test_shared_layout_from_canonical(canonical, expected): expected = expected.strip() if actual != expected: + print(actual) + print(expected) assert False, canonical @@ -393,9 +395,54 @@ def test_canonicalize_shared_layout(canonical): dtype = t_to_dtype[canonical.T] layout = get_shared_layout_from_canonical(canonical) recovered_canonical = canonicalize_shared_layout(layout, dtype) + print(canonical) + print(recovered_canonical) assert recovered_canonical is not None assert recovered_canonical == canonical, f"{recovered_canonical} != {canonical}" if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + # test_canonicalize_shared_layout( + # CanonicalSharedLayout(major_kind="K", swizzle_mode=Tcgen05SwizzleMode.B32_SWIZZLE, SBO=64, LBO=128, m=2, k=2, T=4) + # ) + canonical = CanonicalSharedLayout( + major_kind="K", swizzle_mode=Tcgen05SwizzleMode.B32_SWIZZLE, SBO=64, LBO=128, m=2, k=2, T=4 + ) + expected = """ +┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ +│ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 128 │ 129 │ 130 │ 131 │ 132 │ 133 │ 134 │ 135 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 8 │ 9 │ 10 │ 11 │ 12 │ 13 │ 14 │ 15 │ 136 │ 137 │ 138 │ 139 │ 140 │ 141 │ 142 │ 143 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 16 │ 17 │ 18 │ 19 │ 20 │ 21 │ 22 │ 23 │ 144 │ 145 │ 146 │ 147 │ 148 │ 149 │ 150 │ 151 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 24 │ 25 │ 26 │ 27 │ 28 │ 29 │ 30 │ 31 │ 152 │ 153 │ 154 │ 155 │ 156 │ 157 │ 158 │ 159 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 36 │ 37 │ 38 │ 39 │ 32 │ 33 │ 34 │ 35 │ 164 │ 165 │ 166 │ 167 │ 160 │ 161 │ 162 │ 163 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 44 │ 45 │ 46 │ 47 │ 40 │ 41 │ 42 │ 43 │ 172 │ 173 
│ 174 │ 175 │ 168 │ 169 │ 170 │ 171 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 52 │ 53 │ 54 │ 55 │ 48 │ 49 │ 50 │ 51 │ 180 │ 181 │ 182 │ 183 │ 176 │ 177 │ 178 │ 179 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 60 │ 61 │ 62 │ 63 │ 56 │ 57 │ 58 │ 59 │ 188 │ 189 │ 190 │ 191 │ 184 │ 185 │ 186 │ 187 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 64 │ 65 │ 66 │ 67 │ 68 │ 69 │ 70 │ 71 │ 192 │ 193 │ 194 │ 195 │ 196 │ 197 │ 198 │ 199 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 72 │ 73 │ 74 │ 75 │ 76 │ 77 │ 78 │ 79 │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ 206 │ 207 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 80 │ 81 │ 82 │ 83 │ 84 │ 85 │ 86 │ 87 │ 208 │ 209 │ 210 │ 211 │ 212 │ 213 │ 214 │ 215 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 88 │ 89 │ 90 │ 91 │ 92 │ 93 │ 94 │ 95 │ 216 │ 217 │ 218 │ 219 │ 220 │ 221 │ 222 │ 223 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 100 │ 101 │ 102 │ 103 │ 96 │ 97 │ 98 │ 99 │ 228 │ 229 │ 230 │ 231 │ 224 │ 225 │ 226 │ 227 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 108 │ 109 │ 110 │ 111 │ 104 │ 105 │ 106 │ 107 │ 236 │ 237 │ 238 │ 239 │ 232 │ 233 │ 234 │ 235 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 116 │ 117 │ 118 │ 119 │ 112 │ 113 │ 114 │ 115 │ 244 │ 245 │ 246 │ 247 │ 240 │ 241 │ 242 │ 243 │ +├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ +│ 124 │ 125 │ 126 │ 127 │ 120 │ 121 │ 122 │ 123 │ 252 │ 253 │ 254 │ 255 │ 248 │ 249 │ 250 │ 251 │ +└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ + """ + test_shared_layout_from_canonical(canonical=canonical, expected=expected) + From 47d5607562b40d901419f96d6143e340f797c2a9 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 01:24:21 +0000 Subject: [PATCH 09/17] wip Signed-off-by: Yaoyao Ding --- .../backends/emitters/cuda/tcgen05/copy.py | 9 +--- .../inference/inference_rules/tcgen05/mma.py | 5 +- python/tilus/ir/layout/shared_layout.py | 5 +- tests/instructions/test_tcgen05_mma.py | 5 +- tests/ir/layout/test_tcgen05_smem.py | 48 +------------------ 5 files changed, 9 insertions(+), 63 deletions(-) diff --git a/python/tilus/backends/emitters/cuda/tcgen05/copy.py b/python/tilus/backends/emitters/cuda/tcgen05/copy.py index ab0d4ae8..0d64253c 100644 --- a/python/tilus/backends/emitters/cuda/tcgen05/copy.py +++ b/python/tilus/backends/emitters/cuda/tcgen05/copy.py @@ -188,17 +188,12 @@ def generate_instructions( raise ValueError("No valid instructions generated") - def check_warp_group(self) -> None: - begin = self.current_thread_group_begin - end = self.current_thread_group_end - if begin % 128 != 0 or end - begin != 128: - raise ValueError("The number of threads in the current thread group must be 128") - def emit(self, inst: Tcgen05CopyInst) -> None: shared_tensor = inst.inputs[1].as_shared_tensor() tmem_tensor = inst.inputs[0].as_tmemory_tensor() - self.check_warp_group() + if self.current_num_threads != 1: + raise ValueError("Tcgen05CopyInst can only be emitted in thread group with a single 
thread") if len(shared_tensor.shape) != 2: raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape)) diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py index 60243e1f..e35e6142 100644 --- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py +++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py @@ -61,7 +61,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaSSInst) -> dict[Share a_layout_canonical = generate_canonical_layout( shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + ret[a_tensor] = a_layout_canonical.as_shared_layout() except ValueError: continue else: @@ -77,7 +77,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaSSInst) -> dict[Share b_layout_canonical = generate_canonical_layout( shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]) except ValueError: continue else: @@ -123,7 +123,6 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaTSInst) -> dict[Share ) .as_shared_layout() .permute(dims=[1, 0]) - .simplify() ) ret[b_tensor] = b_layout except ValueError: diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index b5d9ece2..47a98e4d 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -152,7 +152,7 @@ def create( ) def as_numpy_grid(self) -> np.ndarray: - grid_axes = np.meshgrid(*[np.arange(extent) for extent in self.shape]) + grid_axes = np.meshgrid(*[np.arange(extent) for extent in self.shape], indexing="ij") axes = index_vars(num_vars=len(self.shape)) offset = self(*axes) atom_grid = vectorized_evaluate(expr=offset, var2value={axis: grid_axes[i] for i, axis in enumerate(axes)}) @@ -182,9 +182,6 @@ def slice(self, retain_dims: Sequence[int]) -> SharedLayout: return shared_slice(self, retain_dims) - def simplify(self) -> SharedLayout: - raise RuntimeError("No need to simplify anymore.") - def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout: if self.swizzle is not None: raise RuntimeError("Chained swizzle is not supported.") diff --git a/tests/instructions/test_tcgen05_mma.py b/tests/instructions/test_tcgen05_mma.py index f8db7bdf..c249c511 100644 --- a/tests/instructions/test_tcgen05_mma.py +++ b/tests/instructions/test_tcgen05_mma.py @@ -78,8 +78,9 @@ def __call__(self, a_ptr: void_p, b_ptr: void_p, d_ptr: void_p) -> None: self.mbarrier.wait(tma_mbarrier, phase=0) # perform mma - self.tcgen05.mma(a=s_a, b=s_b.transpose(), d=t_d) - self.tcgen05.commit(mma_mbarrier) + with self.single_thread(): + self.tcgen05.mma(a=s_a, b=s_b.transpose(), d=t_d) + self.tcgen05.commit(mma_mbarrier) self.mbarrier.wait(mma_mbarrier, phase=0) # store d from t_d to global diff --git a/tests/ir/layout/test_tcgen05_smem.py b/tests/ir/layout/test_tcgen05_smem.py index f1f9f74b..b08ddf17 100644 --- a/tests/ir/layout/test_tcgen05_smem.py +++ b/tests/ir/layout/test_tcgen05_smem.py @@ -362,8 +362,6 @@ def test_shared_layout_from_canonical(canonical, expected): expected = expected.strip() if actual != expected: - print(actual) - print(expected) assert False, canonical @@ -395,54 +393,10 @@ def 
test_canonicalize_shared_layout(canonical): dtype = t_to_dtype[canonical.T] layout = get_shared_layout_from_canonical(canonical) recovered_canonical = canonicalize_shared_layout(layout, dtype) - print(canonical) - print(recovered_canonical) assert recovered_canonical is not None assert recovered_canonical == canonical, f"{recovered_canonical} != {canonical}" if __name__ == "__main__": - # pytest.main([__file__]) - # test_canonicalize_shared_layout( - # CanonicalSharedLayout(major_kind="K", swizzle_mode=Tcgen05SwizzleMode.B32_SWIZZLE, SBO=64, LBO=128, m=2, k=2, T=4) - # ) - canonical = CanonicalSharedLayout( - major_kind="K", swizzle_mode=Tcgen05SwizzleMode.B32_SWIZZLE, SBO=64, LBO=128, m=2, k=2, T=4 - ) - expected = """ -┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ -│ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 128 │ 129 │ 130 │ 131 │ 132 │ 133 │ 134 │ 135 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 8 │ 9 │ 10 │ 11 │ 12 │ 13 │ 14 │ 15 │ 136 │ 137 │ 138 │ 139 │ 140 │ 141 │ 142 │ 143 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 16 │ 17 │ 18 │ 19 │ 20 │ 21 │ 22 │ 23 │ 144 │ 145 │ 146 │ 147 │ 148 │ 149 │ 150 │ 151 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 24 │ 25 │ 26 │ 27 │ 28 │ 29 │ 30 │ 31 │ 152 │ 153 │ 154 │ 155 │ 156 │ 157 │ 158 │ 159 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 36 │ 37 │ 38 │ 39 │ 32 │ 33 │ 34 │ 35 │ 164 │ 165 │ 166 │ 167 │ 160 │ 161 │ 162 │ 163 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 44 │ 45 │ 46 │ 47 │ 40 │ 41 │ 42 │ 43 │ 172 │ 173 │ 174 │ 175 │ 168 │ 169 │ 170 │ 171 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 52 │ 53 │ 54 │ 55 │ 48 │ 49 │ 50 │ 51 │ 180 │ 181 │ 182 │ 183 │ 176 │ 177 │ 178 │ 179 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 60 │ 61 │ 62 │ 63 │ 56 │ 57 │ 58 │ 59 │ 188 │ 189 │ 190 │ 191 │ 184 │ 185 │ 186 │ 187 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 64 │ 65 │ 66 │ 67 │ 68 │ 69 │ 70 │ 71 │ 192 │ 193 │ 194 │ 195 │ 196 │ 197 │ 198 │ 199 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 72 │ 73 │ 74 │ 75 │ 76 │ 77 │ 78 │ 79 │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ 206 │ 207 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 80 │ 81 │ 82 │ 83 │ 84 │ 85 │ 86 │ 87 │ 208 │ 209 │ 210 │ 211 │ 212 │ 213 │ 214 │ 215 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 88 │ 89 │ 90 │ 91 │ 92 │ 93 │ 94 │ 95 │ 216 │ 217 │ 218 │ 219 │ 220 │ 221 │ 222 │ 223 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 100 │ 101 │ 102 │ 103 │ 96 │ 97 │ 98 │ 99 │ 228 │ 229 │ 230 │ 231 │ 224 │ 225 │ 226 │ 227 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 108 │ 109 │ 110 │ 111 │ 104 │ 105 │ 106 │ 107 │ 236 │ 237 │ 238 │ 239 │ 232 │ 233 │ 234 │ 235 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 116 │ 117 │ 118 │ 119 │ 112 │ 113 │ 114 │ 115 │ 244 │ 245 │ 
246 │ 247 │ 240 │ 241 │ 242 │ 243 │ -├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤ -│ 124 │ 125 │ 126 │ 127 │ 120 │ 121 │ 122 │ 123 │ 252 │ 253 │ 254 │ 255 │ 248 │ 249 │ 250 │ 251 │ -└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ - """ - test_shared_layout_from_canonical(canonical=canonical, expected=expected) + pytest.main([__file__]) From 88f5fb4ddf365c41ec48b462d49d686efc8d724b Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 01:27:46 +0000 Subject: [PATCH 10/17] fix tests Signed-off-by: Yaoyao Ding --- examples/quantization/matmul_a16wx.py | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/quantization/matmul_a16wx.py b/examples/quantization/matmul_a16wx.py index b08ac3e2..70efdf0b 100644 --- a/examples/quantization/matmul_a16wx.py +++ b/examples/quantization/matmul_a16wx.py @@ -22,7 +22,7 @@ int32, uint8, ) -from tilus.ir.layout.ops import concat, local, reduce, spatial +from tilus.ir.layout.ops import concat, local, reduce, spatial, shared_row_major_swizzle from tilus.utils import benchmark_func, cdiv, dtype_to_torch, gcd from torch import nn @@ -217,16 +217,16 @@ def __init__( ) self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True) - # self.layout_sa = self.cuda.swizzled_shared_layout( - # self.a_dtype, shape=[num_stages, self.block_m, self.block_k] - # ) - # self.layout_sb = self.cuda.shared_layout( - # shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes] - # ) - # self.layout_sc = self.cuda.swizzled_shared_layout( - # self.a_dtype, shape=[self.block_m, self.block_n] - # ) - # self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n]) + self.layout_sa = shared_row_major_swizzle( + dtype_nbytes=self.a_dtype.nbytes, shape=[num_stages, self.block_m, self.block_k] + ) + self.layout_sb = self.cuda.shared_layout( + shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes] + ) + self.layout_sc = shared_row_major_swizzle( + dtype_nbytes=self.a_dtype.nbytes, shape=[self.block_m, self.block_n] + ) + self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n]) def __call__( self, @@ -398,10 +398,10 @@ def __call__( ) # annotate layouts - # self.annotate_layout(sc, layout=self.layout_sc) - # self.annotate_layout(sa, layout=self.layout_sa) - # self.annotate_layout(sb, layout=self.layout_sb) - # self.annotate_layout(ss, layout=self.layout_ss) + self.annotate_layout(sc, layout=self.layout_sc) + self.annotate_layout(sa, layout=self.layout_sa) + self.annotate_layout(sb, layout=self.layout_sb) + self.annotate_layout(ss, layout=self.layout_ss) self.annotate_layout(acc, layout=self.mma.lc) From 522e4c709a7df78338ae46c88a5d76798b344446 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 01:28:28 +0000 Subject: [PATCH 11/17] finish Signed-off-by: Yaoyao Ding --- examples/quantization/matmul_a16wx.py | 5 +++-- .../ir/layout/inference/inference_rules/transform_shared.py | 6 ++---- python/tilus/ir/layout/ops/shared_ops.py | 2 ++ python/tilus/ir/layout/utils/cute.py | 2 +- tests/ir/layout/test_tcgen05_smem.py | 1 - 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/quantization/matmul_a16wx.py b/examples/quantization/matmul_a16wx.py index 70efdf0b..aeb54684 100644 --- a/examples/quantization/matmul_a16wx.py +++ b/examples/quantization/matmul_a16wx.py @@ -22,7 +22,7 @@ int32, uint8, ) -from tilus.ir.layout.ops import concat, local, 
reduce, spatial, shared_row_major_swizzle +from tilus.ir.layout.ops import concat, local, reduce, shared_row_major_swizzle, spatial from tilus.utils import benchmark_func, cdiv, dtype_to_torch, gcd from torch import nn @@ -218,7 +218,8 @@ def __init__( self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True) self.layout_sa = shared_row_major_swizzle( - dtype_nbytes=self.a_dtype.nbytes, shape=[num_stages, self.block_m, self.block_k] + dtype_nbytes=self.a_dtype.nbytes, + shape=[num_stages, self.block_m, self.block_k], ) self.layout_sb = self.cuda.shared_layout( shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes] diff --git a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py index 4fd73d57..3b37c298 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py @@ -31,10 +31,8 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared if inst.dims is None: dims = list(range(len(a.shape))) else: - dims = inst.dims - return { - b: a.layout.slice(retain_dims=dims) - } + dims = list(inst.dims) + return {b: a.layout.slice(retain_dims=dims)} elif b.optional_layout is not None: b_layout = b.layout.unsqueeze(dims=range(len(a.shape) - len(b.shape))) outer_shape = [] diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index 260344d6..dde80b68 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -153,6 +153,7 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=layout.swizzle) + def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayout: """Slice the shared layout by removing specified dimensions. @@ -186,6 +187,7 @@ def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayo swizzle=layout.swizzle, ) + def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: """Unsqueeze the shared layout by adding new dimensions of size 1. 
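# Aside (illustrative sketch, not part of this patch): what the retain_dims-based slicing
# above is expected to produce for a plain row-major layout. The concrete shape below is an
# arbitrary assumption, chosen only to make the semantics easy to check.
import numpy as np

from tilus.ir.layout.ops.shared_ops import shared_row_major

layout = shared_row_major(4, 8, 2)         # mode strides (16, 2, 1)
sliced = layout.slice(retain_dims=[0, 2])  # keep dims 0 and 2, drop dim 1

assert sliced.shape == (4, 2)
# retained dims keep their original strides, so element (i, k) still maps to offset 16 * i + k
expected = 16 * np.arange(4)[:, None] + np.arange(2)[None, :]
assert (sliced.as_numpy_grid() == expected).all()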
diff --git a/python/tilus/ir/layout/utils/cute.py b/python/tilus/ir/layout/utils/cute.py index f9e277ed..d66ed4f9 100644 --- a/python/tilus/ir/layout/utils/cute.py +++ b/python/tilus/ir/layout/utils/cute.py @@ -132,7 +132,7 @@ def reverse_int_tuple(t: IntTuple) -> IntTuple: return tuple(reverse_int_tuple(item) for item in reversed(t)) else: return t - + assert isinstance(self.layout.shape, Sequence) assert isinstance(self.layout.strides, Sequence) diff --git a/tests/ir/layout/test_tcgen05_smem.py b/tests/ir/layout/test_tcgen05_smem.py index b08ddf17..e3901edc 100644 --- a/tests/ir/layout/test_tcgen05_smem.py +++ b/tests/ir/layout/test_tcgen05_smem.py @@ -399,4 +399,3 @@ def test_canonicalize_shared_layout(canonical): if __name__ == "__main__": pytest.main([__file__]) - From 62c3d05e3b062aa2bb3ca6b3d8a24874fa560e68 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 06:59:31 +0000 Subject: [PATCH 12/17] fix test Signed-off-by: Yaoyao Ding --- .../inference_rules/transform_shared.py | 2 +- python/tilus/ir/layout/shared_layout.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py index 3b37c298..cd653ced 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py @@ -38,7 +38,7 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared outer_shape = [] for i in range(len(a.shape)): outer_shape.append(a.shape[i] // b_layout.shape[i]) - return {a: shared_compose(shared_row_major(*outer_shape), b_layout)} + return {a: shared_compose(shared_row_major(*outer_shape), b_layout).apply_swizzle(b.layout.swizzle)} else: return {} diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 47a98e4d..9be5b067 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -86,7 +86,7 @@ class SharedLayout(IRNode): shape: tuple[int, ...] mode_shape: tuple[int, ...] mode_strides: tuple[int, ...] - swizzle: Optional[Swizzle] + optional_swizzle: Optional[Swizzle] def __call__(self, *indices: Expr) -> Expr: """Compute the offset on given indices. 
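# Aside (illustrative reference, not part of this patch): the offset computation that the
# hunk below implements, spelled out in plain Python for a hypothetical layout with
# shape (8, 4), mode_shape (2, 4, 4), and mode_strides (16, 4, 1). Dimension 0 is split
# into two modes of sizes (2, 4); dimension 1 stays a single mode of size 4.
def reference_offset(i: int, j: int) -> int:
    i1, i2 = i // 4, i % 4           # deserialize the dim-0 index into its two modes
    return i1 * 16 + i2 * 4 + j * 1  # dot product with the mode strides (swizzle omitted)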
@@ -113,11 +113,17 @@ def __call__(self, *indices: Expr) -> Expr: total_index: Expr = as_expr(sum(index * stride for index, stride in zip(mode_indices, self.mode_strides))) # apply swizzle if exists - if self.swizzle is not None: - total_index = self.swizzle(total_index) + if self.optional_swizzle is not None: + total_index = self.optional_swizzle(total_index) return total_index + @property + def swizzle(self) -> Swizzle: + if self.optional_swizzle is None: + raise ValueError("No swizzle is applied on this layout.") + return self.optional_swizzle + @staticmethod def create( shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle] @@ -148,7 +154,7 @@ def create( if prod(mode_shape) != prod(shape): raise ValueError("The product of mode_shape must equal to the product of shape.") return SharedLayout( - shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), swizzle=swizzle + shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), optional_swizzle=swizzle ) def as_numpy_grid(self) -> np.ndarray: @@ -183,7 +189,7 @@ def slice(self, retain_dims: Sequence[int]) -> SharedLayout: return shared_slice(self, retain_dims) def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout: - if self.swizzle is not None: + if self.optional_swizzle is not None: raise RuntimeError("Chained swizzle is not supported.") return SharedLayout.create( shape=self.shape, @@ -205,7 +211,7 @@ def prepend_dim(self, extent: int) -> SharedLayout: shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, - swizzle=self.swizzle, + swizzle=self.optional_swizzle, ) def transpose(self) -> SharedLayout: From 9dd3e2b679e5e83ff96f2e7e5cef6342ec722df2 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 07:24:30 +0000 Subject: [PATCH 13/17] fix Signed-off-by: Yaoyao Ding --- .../inference_rules/transform_shared.py | 5 +++- python/tilus/ir/layout/ops/shared_ops.py | 16 +++++++------ python/tilus/ir/layout/shared_layout.py | 24 ++++++++++++------- python/tilus/ir/layout/utils/cute.py | 5 +++- python/tilus/ir/tools/printer.py | 2 +- 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py index cd653ced..98244436 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py @@ -38,7 +38,10 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared outer_shape = [] for i in range(len(a.shape)): outer_shape.append(a.shape[i] // b_layout.shape[i]) - return {a: shared_compose(shared_row_major(*outer_shape), b_layout).apply_swizzle(b.layout.swizzle)} + layout = shared_compose(shared_row_major(*outer_shape), b_layout) + if b.layout.optional_swizzle is not None: + layout = layout.apply_swizzle(b.layout.swizzle) + return {a: layout} else: return {} diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index dde80b68..bf3cba51 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -61,7 +61,7 @@ def shared_row_major(*shape: int) -> SharedLayout: """ mode_shape = shape mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(range(len(mode_shape)))) - return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) + return 
shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) def shared_column_major(*shape: int) -> SharedLayout: @@ -79,7 +79,7 @@ def shared_column_major(*shape: int) -> SharedLayout: """ mode_shape = shape mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(reversed(range(len(mode_shape))))) - return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: @@ -118,7 +118,7 @@ def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: mode_strides.extend([stride * rhs_size for stride in (lhs.mode_strides[i] for i in lhs_group)]) mode_strides.extend([rhs.mode_strides[i] for i in rhs_group]) - return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: @@ -151,7 +151,9 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: mode_shape.extend([layout.mode_shape[i] for i in layout_mode_groups[d]]) mode_strides.extend([layout.mode_strides[i] for i in layout_mode_groups[d]]) - return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=layout.swizzle) + return shared_layout( + shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=layout.optional_swizzle + ) def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayout: @@ -184,7 +186,7 @@ def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayo shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, - swizzle=layout.swizzle, + optional_swizzle=layout.optional_swizzle, ) @@ -211,7 +213,7 @@ def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: shape=shape, mode_shape=layout.mode_shape, mode_strides=layout.mode_strides, - swizzle=layout.swizzle, + optional_swizzle=layout.optional_swizzle, ) @@ -373,7 +375,7 @@ def shared_row_major_swizzle(shape: Sequence[int], dtype_nbytes: int) -> SharedL shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, - swizzle=swizzle, + optional_swizzle=swizzle, ) diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 9be5b067..c9c8dbe2 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -126,7 +126,10 @@ def swizzle(self) -> Swizzle: @staticmethod def create( - shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle] + shape: Sequence[int], + mode_shape: Sequence[int], + mode_strides: Sequence[int], + optional_swizzle: Optional[Swizzle], ) -> SharedLayout: """ Create a SharedLayout from shape, mode_shape, and mode_strides. 
@@ -154,7 +157,10 @@ def create( if prod(mode_shape) != prod(shape): raise ValueError("The product of mode_shape must equal to the product of shape.") return SharedLayout( - shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), optional_swizzle=swizzle + shape=tuple(shape), + mode_shape=tuple(mode_shape), + mode_strides=tuple(mode_strides), + optional_swizzle=optional_swizzle, ) def as_numpy_grid(self) -> np.ndarray: @@ -195,7 +201,7 @@ def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout: shape=self.shape, mode_shape=self.mode_shape, mode_strides=self.mode_strides, - swizzle=swizzle, + optional_swizzle=swizzle, ) def prepend_dim(self, extent: int) -> SharedLayout: @@ -211,7 +217,7 @@ def prepend_dim(self, extent: int) -> SharedLayout: shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, - swizzle=self.optional_swizzle, + optional_swizzle=self.optional_swizzle, ) def transpose(self) -> SharedLayout: @@ -238,7 +244,7 @@ def shared_layout( shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], - swizzle: Optional[Swizzle] = None, + optional_swizzle: Optional[Swizzle] = None, ) -> SharedLayout: """Create a SharedLayout from shape, mode_shape, and mode_strides. @@ -270,7 +276,9 @@ def shared_layout( mode_strides = updated_mode_strides # canonicalize swizzle: if swizzle has 0 bits, set it to None (both mean no swizzle) - if swizzle is not None and swizzle.bits == 0: - swizzle = None + if optional_swizzle is not None and optional_swizzle.bits == 0: + optional_swizzle = None - return SharedLayout.create(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=swizzle) + return SharedLayout.create( + shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=optional_swizzle + ) diff --git a/python/tilus/ir/layout/utils/cute.py b/python/tilus/ir/layout/utils/cute.py index d66ed4f9..09327b16 100644 --- a/python/tilus/ir/layout/utils/cute.py +++ b/python/tilus/ir/layout/utils/cute.py @@ -156,7 +156,10 @@ def flatten_int_tuple(t: IntTuple) -> list[Int]: mode_strides = [int(s) for s in flat_strides] return shared_layout( - shape=tensor_shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=self.swizzle.as_swizzle() + shape=tensor_shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + optional_swizzle=self.swizzle.as_swizzle(), ) diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index a37cb634..d94d33d9 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -452,7 +452,7 @@ def visit_SharedLayout(self, node: SharedLayout) -> Doc: "shape=[" + self(node.shape) + "]", "mode_shape=[" + self(node.mode_shape) + "]", "mode_strides=[" + self(node.mode_strides) + "]", - "swizzle=" + (str(node.swizzle) if node.swizzle is not None else "None"), + "swizzle=" + (str(node.swizzle) if node.optional_swizzle is not None else "None"), ] doc = Text("SharedLayout(") + doc_join(items, ", ") + ")" return self.add_key_comment("shared_layout", doc) From b60df1909b1c20cd788247c98e96bb8785e02cdb Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 07:39:11 +0000 Subject: [PATCH 14/17] fix Signed-off-by: Yaoyao Ding --- .../tilus/ir/layout/inference/inference_rules/wgmma.py | 10 ++++++---- tests/examples/test_examples.py | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index 630c99f6..f887fc34 
100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -45,7 +45,9 @@ def generate_wgmma_register_layout(m: int, n: int, inst_m: int, inst_n: int, ins @register_rule(WgmmaMmaSSInst) class WgmmaMmaSSRule(LayoutInferenceRule): @staticmethod - def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedTensor, SharedLayout]: + def inference( + ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst + ) -> dict[SharedTensor | RegisterTensor, SharedLayout | RegisterLayout]: a_tensor: SharedTensor = inst.inputs[0].as_shared_tensor() b_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() d_tensor: RegisterTensor = inst.inputs[2].as_register_tensor() @@ -64,7 +66,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT ) m, n, k = d_shape[0], d_shape[1], a_shape[1] - ret = {} + ret: dict[SharedTensor | RegisterTensor, SharedLayout | RegisterLayout] = {} if not a_tensor.has_layout(): for swizzle_mode in [ Tcgen05SwizzleMode.B128_SWIZZLE, @@ -76,7 +78,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT a_layout_canonical = generate_canonical_layout( shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + ret[a_tensor] = a_layout_canonical.as_shared_layout() except ValueError: continue else: @@ -92,7 +94,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT b_layout_canonical = generate_canonical_layout( shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]) except ValueError: continue else: diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index c284d40a..28881c0a 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -23,7 +23,7 @@ from typing import Optional import pytest -from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100a +from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90a, nvgpu_sm100a # Get the project root directory PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -55,9 +55,9 @@ ("blackwell_matmul", "matmul_v4.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v5.py", nvgpu_sm100a), # hopper matmul example (SM 9.0) - ("hopper_matmul", "matmul_v0.py", nvgpu_sm90), - ("hopper_matmul", "matmul_v1.py", nvgpu_sm90), - ("hopper_matmul", "matmul_v2.py", nvgpu_sm90), + ("hopper_matmul", "matmul_v0.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v2.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+) From f905bc59be1d1a8ea117d2321210f981fcf25f9b Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 07:48:56 +0000 Subject: [PATCH 15/17] fix Signed-off-by: Yaoyao Ding --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce387d71..f31578f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: pass_filenames: false - id: mypy name: MyPy type checking - entry: mypy + entry: bash -c 'mypy --version >&2; mypy "$@"' -- language: 
system types_or: [python, pyi] args: [--show-error-codes, --show-error-context] From 474821bbd7fe2f92d36f8433090df3b5bd53eb67 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 08:09:37 +0000 Subject: [PATCH 16/17] fix Signed-off-by: Yaoyao Ding --- python/tilus/backends/emitters/cuda/wgmma.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index 4435543f..d2e891a3 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -180,7 +180,7 @@ def emit_wgmma(self, inst: WgmmaMmaSSInst) -> None: a_desc.encoded(), d_register_addr + d_offset, b_desc.encoded(), - trans_a=0, - trans_b=0, + trans_a=0, # type: ignore + trans_b=0, # type: ignore ) ) From 5c5ca583e4540c50288b465bcdefafd925ac801f Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 8 Dec 2025 08:24:03 +0000 Subject: [PATCH 17/17] format & lint Signed-off-by: Yaoyao Ding --- .pre-commit-config.yaml | 1 - python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31578f0..cb99c2e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,6 @@ repos: types_or: [python, pyi] args: [--show-error-codes, --show-error-context] files: ^(python|examples|tests)/ - exclude: ^python/tilus/extensions/hidet/ require_serial: true - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version diff --git a/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py b/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py index 81a988a1..180a08e0 100644 --- a/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py @@ -207,9 +207,8 @@ def mbarrier_arrive_shared(mbarrier_addr: Expr, count: Expr | int) -> Expr: -------- mbarrier.arrive : PTX ISA documentation section 9.7.13.15.13 """ - if isinstance(count, int): - count = u32(count) - return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count]) + count_expr = count if isinstance(count, Expr) else u32(count) + return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count_expr]) def mbarrier_arrive_remote_shared(