From c73b7aacb00142308787bdeaca8d431659eb0aaa Mon Sep 17 00:00:00 2001 From: Balint Rozsa Date: Fri, 15 Aug 2025 18:00:38 -0400 Subject: [PATCH 1/3] Rework memref.alloc and support memory spaces --- src/pydsl/gpu.py | 17 +++ src/pydsl/memref.py | 222 +++++++++++++++++------------------ tests/e2e/test_algorithms.py | 9 +- tests/e2e/test_linalg.py | 2 +- tests/e2e/test_memref.py | 81 ++++++------- 5 files changed, 165 insertions(+), 166 deletions(-) create mode 100644 src/pydsl/gpu.py diff --git a/src/pydsl/gpu.py b/src/pydsl/gpu.py new file mode 100644 index 0000000..4a75d4f --- /dev/null +++ b/src/pydsl/gpu.py @@ -0,0 +1,17 @@ +from enum import Enum + +from mlir.ir import AttrBuilder +import mlir.dialects.gpu as gpu + +from pydsl.memref import MemorySpace + + +class GPU_AddrSpace(MemorySpace, Enum): + Global = gpu.AddressSpace.Global + Workgroup = gpu.AddressSpace.Workgroup + Private = gpu.AddressSpace.Private + + def lower(self): + return ( + AttrBuilder.get("GPU_AddressSpaceAttr")(self.value, context=None), + ) diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index 9914885..a51cfe3 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -1,9 +1,11 @@ import ast import ctypes import typing + from collections.abc import Callable, Iterable from ctypes import POINTER, c_void_p from dataclasses import dataclass +from enum import Enum from functools import cache from typing import TYPE_CHECKING, Final @@ -25,7 +27,6 @@ canonicalize_args, SubtreeOut, ToMLIRBase, - lower, ) from pydsl.type import ( Index, @@ -55,6 +56,31 @@ Shape = typing.TypeVarTuple("Shape") +class MemorySpace(Enum): + """ + Superclass for memory spaces. Mostly used for making type-hints nicer and + type checking. It would probably make more sense for this to be an Enum + with 1 element, or an ABC, but neither of those are possible, so we use an + Enum with 0 elements instead. + + Subclasses should implement lower to define what Attribute object in MLIR + they correspond to. lower_class is only defined so that this is considered + a Lowerable. + """ + + def lower_class(cls): + raise AssertionError( + f"class of {cls.__qualname__} cannot be lowered, only its " + f"instances" + ) + + def lower(self) -> tuple[mlir.Attribute | None]: + raise AssertionError( + "MemorySpace.lower should never be called, subclasses of " + "MemorySpace should define their own lower methods" + ) + + @dataclass class RankedMemRefDescriptor: """ @@ -372,13 +398,12 @@ class MemRef(typing.Generic[DType, *Shape], UsesRMRD): bytes. """ - # Only strides is allowed to be None. If anything else remains None, - # something has gone wrong. value: Value - shape: tuple[int] = None - element_type: Lowerable = None - offset: int = None - strides: tuple[int] | None = None + shape: tuple[int] + element_type: Lowerable + offset: int + strides: tuple[int] | None + memory_space: MemorySpace | None _default_subclass_name = "AnnonymousMemRefSubclass" _supported_mlir_type = [ @@ -399,7 +424,8 @@ def class_factory( *, offset: int = 0, strides: tuple[int] | None = None, - name=_default_subclass_name, + memory_space: MemorySpace | None = None, + name: str = _default_subclass_name, ): """ Create a new subclass of MemRef with the specified dimensions and type. 
@@ -428,6 +454,15 @@ def class_factory( f"MemRef requires shape to be iterable, got {type(shape)}" ) + if strides is not None: + strides = tuple(strides) + + if not isinstance(memory_space, (MemorySpace, type(None))): + raise TypeError( + f"MemRef memory_space must be an instance of MemorySpace or ", + f"None, got {type(memory_space)}", + ) + return type( name, (MemRef,), @@ -435,7 +470,8 @@ def class_factory( "shape": tuple(shape), "element_type": element_type, "offset": int(offset), - "strides": None if strides is None else tuple(strides), + "strides": strides, + "memory_space": memory_space, }, ) @@ -635,19 +671,25 @@ def lower_class(cls) -> tuple[mlir.Type]: e.add_note(f"hint: class name is {clsname}") raise e - if cls.strides is None: - return ( - MemRefType.get( - list(cls.shape), lower_single(cls.element_type) - ), - ) - else: - layout = StridedLayoutAttr.get(cls.offset, list(cls.strides)) - return ( - MemRefType.get( - list(cls.shape), lower_single(cls.element_type), layout - ), - ) + layout = ( + None + if cls.strides is None + else StridedLayoutAttr.get(cls.offset, list(cls.strides)) + ) + memory_space = ( + None + if cls.memory_space is None + else lower_single(cls.memory_space) + ) + + return ( + MemRefType.get( + list(cls.shape), + lower_single(cls.element_type), + layout, + memory_space, + ), + ) @property def runtime_shape(self) -> RuntimeMemrefShape: @@ -671,122 +713,74 @@ def runtime_shape(self) -> RuntimeMemrefShape: MemRefFactory = MemRef.class_factory -def verify_memory(mem: MemRef): - if not isinstance(mem, MemRef): - raise TypeError( - f"the type being allocated must be a subclass of MemRef, got {mem}" - ) - - -def verify_memory_type(mtype: type[MemRef]): - if not issubclass(mtype, MemRef): - raise TypeError( - f"the type being allocated must be a subclass of MemRef, got " - f"{mtype}" - ) - - -def verify_dynamic_sizes(mtype: type[MemRef], dynamic_sizes: Tuple) -> None: - dynamic_sizes = lower(dynamic_sizes) - - # TODO: does this check do anything, since lower returns a tuple? - if not isinstance(dynamic_sizes, Iterable): - raise TypeError(f"{repr(dynamic_sizes)} is not iterable") - - if (actual_dyn := len(dynamic_sizes)) != ( - target_dyn := mtype.shape.count(DYNAMIC) - ): - raise ValueError( - f"MemRef has {target_dyn} dynamic dimensions to be filled, " - f"but alloc/alloca received {actual_dyn}" - ) - - -def verify_dynamic_symbols( - mtype: type[MemRef], dynamic_symbols: Tuple -) -> None: - dynamic_symbols = lower(dynamic_symbols) - - # TODO: does this check do anything, since lower returns a tuple? - if not isinstance(dynamic_symbols, Iterable): - raise TypeError(f"{repr(dynamic_symbols)} is not iterable") - - if (actual_dyn := len(dynamic_symbols)) != ( - target_dyn := 0 - if mtype.strides is None - else mtype.strides.count(DYNAMIC) - ): - raise ValueError( - f"MemRef has {target_dyn} dynamic strides to be filled, " - f"but alloc/alloca received {actual_dyn}" - ) - - def _alloc_generic( visitor: ToMLIRBase, - mtype: Compiled, - dynamic_sizes: Compiled, - dynamic_symbols: Compiled, - alloc_func: Callable[..., SubtreeOut], + alloc_func: Callable, + shape: Compiled, + dtype: Evaluated, + memory_space: MemorySpace | None = None, ) -> SubtreeOut: """ - Does the logic required for alloc/alloca. It was silly having - two functions that differed by only one character. alloc_func - should be memref.alloc or memref.alloca. + Does the logic required for alloc/alloca. It was silly having two functions + that differed by only one character. 
alloc_func should be memref.alloc or + memref.alloca. Currently only supports allocating non-strided MemRefs of + default layout. """ - if dynamic_sizes is None: - dynamic_sizes = Tuple.from_values(visitor, *()) - - if dynamic_symbols is None: - dynamic_symbols = Tuple.from_values(visitor, *()) - - verify_memory_type(mtype) - - if mtype.strides is not None: - raise NotImplementedError( - "allocating MemRefs with a strided layout is currently" - "not supported, since there seems to be no way to lower" - "the resulting MLIR to LLVMIR. Allocate a MemRef with" - "consecutive memory instead" + # NOTE: the dynamic_symbols parameter of memref.alloc is relevant for + # allocating MemRefs with an affine map layout. MLIR also supports + # allocating a strided MemRef, you simply change m_type to be a strided + # MemRef type. However, it seems we don't know how to lower such + # allocations from MLIR -> LLVMIR, so this feature is not implemented now. + # If this feature is implemented in the future, you can steal + # test_alloca_strided test case from an older version of test_memref.py + # (although that uses slightly different syntax). + + if not isinstance(shape, Tuple): + raise TypeError( + f"shape should be a Tuple, got {type(shape).__qualname__}" ) - verify_dynamic_sizes(mtype, dynamic_sizes) - verify_dynamic_symbols(mtype, dynamic_symbols) - dynamic_sizes = [lower_single(Index(i)) for i in lower(dynamic_sizes)] - dynamic_symbols = [lower_single(Index(i)) for i in lower(dynamic_symbols)] + shape = shape.as_iterable(visitor) + static_shape, dynamic_sizes = split_static_dynamic_dims(shape) + + m_type = MemRefFactory( + tuple(static_shape), dtype, memory_space=memory_space + ) - return mtype( - alloc_func(lower_single(mtype), dynamic_sizes, dynamic_symbols) + return m_type( + alloc_func(lower_single(m_type), lower_flatten(dynamic_sizes), []) ) @CallMacro.generate() def alloca( visitor: ToMLIRBase, - mtype: Compiled, - dynamic_sizes: Compiled = None, - dynamic_symbols: Compiled = None, + shape: Compiled, + dtype: Evaluated, + *, + memory_space: Evaluated = None, ) -> SubtreeOut: - return _alloc_generic( - visitor, mtype, dynamic_sizes, dynamic_symbols, memref.alloca - ) + return _alloc_generic(visitor, memref.alloca, shape, dtype, memory_space) @CallMacro.generate() def alloc( visitor: ToMLIRBase, - mtype: Compiled, - dynamic_sizes: Compiled = None, - dynamic_symbols: Compiled = None, + shape: Compiled, + dtype: Evaluated, + *, + memory_space: Evaluated = None, ) -> SubtreeOut: - return _alloc_generic( - visitor, mtype, dynamic_sizes, dynamic_symbols, memref.alloc - ) + return _alloc_generic(visitor, memref.alloc, shape, dtype, memory_space) @CallMacro.generate() -def dealloc(visitor: ToMLIRBase, mem: Evaluated) -> None: - verify_memory(mem) +def dealloc(visitor: ToMLIRBase, mem: Compiled) -> None: + if not isinstance(mem, MemRef): + raise TypeError( + f"the type being deallocated must be a MemRef, got {type(mem)}" + ) + return memref.dealloc(lower_single(mem)) diff --git a/tests/e2e/test_algorithms.py b/tests/e2e/test_algorithms.py index eb183d7..b9223a7 100644 --- a/tests/e2e/test_algorithms.py +++ b/tests/e2e/test_algorithms.py @@ -23,7 +23,7 @@ tile, ) from pydsl.transform import match_tag as match -from pydsl.type import F32, F64, AnyOp, Index, Tuple +from pydsl.type import AnyOp, F32, F64, Index, Number, Tuple from helper import multi_arange, run MemRefRank2F32 = MemRefFactory((DYNAMIC, DYNAMIC), F32) @@ -334,6 +334,9 @@ def test_softmax(): N = 64 M = 2048 neg_inf = -math.inf + # TODO: remove 
once we support constexprs better + N_num = Number(N) + M_num = Number(M) @InlineFunction.generate() def _add(a, b) -> Any: @@ -341,11 +344,11 @@ def _add(a, b) -> Any: @compile() def softmax_memref(arr: MemRef[F64, N, M]) -> MemRef[F64, N, M]: - reduce_res = alloca(MemRef[F64, N]) + reduce_res = alloca((N_num,), F64) linalg.fill(reduce_res, neg_inf) linalg.reduce(arith.max, arr, init=reduce_res, dims=[1]) - mx = alloca(MemRef[F64, N, M]) + mx = alloca((N_num, M_num), F64) linalg.broadcast(reduce_res, out=mx, dims=[1]) linalg.sub(arr, mx, out=arr) diff --git a/tests/e2e/test_linalg.py b/tests/e2e/test_linalg.py index d640ced..813b327 100644 --- a/tests/e2e/test_linalg.py +++ b/tests/e2e/test_linalg.py @@ -385,7 +385,7 @@ def combine(a, b) -> typing.Any: @compile() def f(arr: MemRef[UInt8, 4, 4]) -> UInt64: - out = alloca(MemRef[UInt64]) + out = alloca((), UInt64) linalg.reduce(combine, arr, init=out, dims=[0, 1]) return out[()] diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index f11a90b..a48fc00 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -6,8 +6,9 @@ from pydsl.affine import affine_range as arange from pydsl.frontend import compile +from pydsl.gpu import GPU_AddrSpace import pydsl.linalg as linalg -from pydsl.memref import alloc, alloca, DYNAMIC, Dynamic, MemRef, MemRefFactory +from pydsl.memref import alloc, alloca, dealloc, DYNAMIC, MemRef, MemRefFactory from pydsl.type import Bool, F32, F64, Index, SInt16, Tuple, UInt32 from helper import compilation_failed_from, failed_from, multi_arange, run @@ -100,7 +101,7 @@ def f(m: MemRef[F64, 2, 2]) -> Tuple[MemRef[F64, 2, 2]]: def test_alloca_scalar(): @compile() def f() -> UInt32: - m_scalar = alloca(MemRef[UInt32, 1]) + m_scalar = alloca((1,), UInt32) for i in arange(10): m_scalar[0] = i @@ -113,7 +114,7 @@ def f() -> UInt32: def test_alloca_dynamic(): @compile() def f(a: Index, b: Index) -> Tuple[UInt32, UInt32]: - m = alloca(MemRef[UInt32, Dynamic, 2, Dynamic], (a, b)) + m = alloca((a, 2, b), UInt32) m[0, 0, 0] = 1 m[a - 1, 1, b - 1] = 2 @@ -127,62 +128,46 @@ def f(a: Index, b: Index) -> Tuple[UInt32, UInt32]: def test_alloc_scalar(): @compile() def f() -> MemRef[UInt32, 1]: - m_scalar = alloc(MemRef[UInt32, 1]) + m_scalar = alloc((1,), UInt32) m_scalar[0] = 1 return m_scalar assert (f() == np.asarray([1], dtype=np.uint32)).all() -def test_alloc_wrong_dynamic_sizes(): - with compilation_failed_from(ValueError): - # 2 dynamic dimensions, 1 specified - @compile() - def f1(a: Index): - m1 = alloc(MemRef[F32, 4, 5, DYNAMIC, DYNAMIC], (a,)) - - with compilation_failed_from(ValueError): - # 1 dynamic dimension, 3 specified - @compile() - def f2(a: Index, b: Index, c: Index): - m1 = alloc(MemRef[UInt32, DYNAMIC, 8], (a, b, c)) +def test_dealloc(): + """ + Doesn't really test whether the memref is deallocated, since that's hard + to test, but does test that the syntax of dealloc works. 
+ """ - with compilation_failed_from(ValueError): - # 0 dynamic dimensins, 2 specified - @compile() - def f3(a: Index, b: Index): - m1 = alloc(MemRef[SInt16, 4, 4], (a, b)) + @compile() + def f() -> SInt16: + m = alloc((3, Index(7)), SInt16) + m[1, 2] = 5 + res = m[1, 2] + dealloc(m) + return res - with compilation_failed_from(ValueError): - # 1 dynamic dimension, 0 specified - @compile() - def f4(): - m1 = alloc(MemRef[F64, DYNAMIC]) + assert f() == 5 + mlir = f.emit_mlir() + assert r"memref.dealloc" in mlir -def test_alloca_strided(): - # TODO: currently, it seems we cannot lower alloc/alloca - # calls that use memrefs of non-trivial layouts from MLIR -> LLVMIR. - # The code for generating the MLIR from Python exists (but cannot be tested). - # If we find the right MLIR pass for this lowering, we should make - # this code compile and add more tests that check alloc/alloca of - # strided MemRefs (e.g. dynamic strides, invalid dynamic_symbols etc.). - with compilation_failed_from(NotImplementedError): - MemRefStrided = MemRefFactory((2, 3), SInt16, strides=(4, 7)) +def test_alloc_memory_space(): + """ + We can't use GPU address spaces on CPU, so just test if it compiles and + the MLIR is reasonable. + """ - @compile() - def f() -> Tuple[SInt16, SInt16, SInt16]: - m1 = alloca(MemRefStrided) - m1[0, 0] = 5 - m1[0, 1] = 8 - m1[0, 2] = 40 - m1[1, 0] = 3 - m1[1, 1] = 5 - m1[1, 2] = -12 - return m1[0, 1], m1[0, 2], m1[1, 2] + @compile(auto_build=False) + def f(): + m = alloc((4, 2), F32, memory_space=GPU_AddrSpace.Global) + dealloc(m) - assert f() == (8, 40, -12) + mlir = f.emit_mlir() + assert r"memref<4x2xf32, #gpu.address_space>" in mlir def test_load_strided(): @@ -459,8 +444,8 @@ def f(t1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: run(test_alloca_scalar) run(test_alloca_dynamic) run(test_alloc_scalar) - run(test_alloc_wrong_dynamic_sizes) - run(test_alloca_strided) + run(test_dealloc) + run(test_alloc_memory_space) run(test_load_strided) run(test_load_strided_big) run(test_load_strided_wrong) From da0e34d1259f71180bbfcc7b8ce2994e6ce138da Mon Sep 17 00:00:00 2001 From: Balint Rozsa Date: Wed, 20 Aug 2025 11:28:03 -0400 Subject: [PATCH 2/3] Implement tensor and memref casting --- src/pydsl/memref.py | 94 +++++++++++++++++++++++++++++++++++++++- src/pydsl/tensor.py | 41 +++++++++++++++++- tests/e2e/test_memref.py | 57 ++++++++++++++++++++++++ tests/e2e/test_tensor.py | 27 ++++++++++++ 4 files changed, 217 insertions(+), 2 deletions(-) diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index a51cfe3..5048ab4 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -21,7 +21,7 @@ ) from pydsl.affine import AffineContext, AffineMapExpr, AffineMapExprWalk -from pydsl.macro import CallMacro, Compiled, Evaluated +from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType from pydsl.protocols import ( ArgContainer, canonicalize_args, @@ -181,6 +181,18 @@ def are_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> bool: ) +def assert_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> None: + """ + Checks that non-dynamic dimensions of arr1 and arr2 match, throws a + ValueError if not. 
+ """ + if not are_shapes_compatible(arr1, arr2): + raise ValueError( + f"incompatible shapes: {repr(arr1)} and {repr(arr2)}, non-dynamic " + f"dimensions must be equal" + ) + + class UsesRMRD: """ A mixin class for adding CType support for classes that eventually lower @@ -658,6 +670,86 @@ def on_class_getitem( # Equivalent to cls.__class_getitem__(args) return cls[args] + @CallMacro.generate(method_type=MethodType.INSTANCE) + def cast( + visitor: ToMLIRBase, + self: typing.Self, + shape: Evaluated = None, + *, + offset: Evaluated = None, + strides: Evaluated = (-1,), + ) -> typing.Self: + """ + Converts a memref from one type to an equivalent type with a compatible + shape. The source and destination types are compatible if all of the + following are true: + - Both are ranked memref types with the same element type, address + space, and rank. + - Both have the same layout or both have compatible strided layouts. + - The individual sizes (resp. offset and strides in the case of strided + memrefs) may convert constant dimensions to dynamic dimensions and + vice-versa. + + If the cast converts any dimensions from an unknown to a known size, + then it acts as an assertion that fails at runtime if the dynamic + dimensions disagree with the resultant destination size (i.e. it is + illegal to do a conversion that causes a mismatch, and it would invoke + undefined behaviour). + + Note: a new memref with the new type is returned, the type of the + original memref is not modified. + + Example: + ``` + def f(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]: + # Only valid if the first dimension of m1 is always 64 + m2 = m1.cast((64, 32, DYNAMIC)) + return m2 + ``` + """ + # NOTE: default value of strides is (-1,) because None is already used + # to represent default layout... This is definitely not a great + # solution, but it's unclear how to do this better. + + shape = tuple(shape) if shape is not None else self.shape + offset = int(offset) if offset is not None else self.offset + strides = ( + None + if strides is None + else self.strides + if strides == (-1,) + else tuple(strides) + ) + + if not all(isinstance(x, int) for x in shape): + raise ValueError( + f"shape should be a tuple of integers known at compile time ", + f"got {repr(shape)}", + ) + + if not (strides is None or all(isinstance(x, int) for x in strides)): + raise ValueError( + f"strides should be a tuple of integers known at compile ", + f"time, got {repr(strides)}", + ) + + assert_shapes_compatible(self.shape, shape) + assert_shapes_compatible([self.offset], [offset]) + + # TODO: also do a type check if one of the MemRefs is default layout. + # This is a reasonably complicated check, and even MLIR doesn't do it + # properly. E.g. mlir-opt allows + # `memref<3x4xf64, strided<[8, 1], offset: ?>> to memref` + # to compile, even though it is impossible for this to be correct. 
+ if self.strides is not None and strides is not None: + assert_shapes_compatible(self.strides, strides) + + result_type = self.class_factory( + shape, self.element_type, offset=offset, strides=strides + ) + rep = memref.cast(lower_single(result_type), lower_single(self)) + return result_type(rep) + def lower(self) -> tuple[Value]: return (self.value,) diff --git a/src/pydsl/tensor.py b/src/pydsl/tensor.py index 18c9765..6854657 100644 --- a/src/pydsl/tensor.py +++ b/src/pydsl/tensor.py @@ -8,8 +8,9 @@ from mlir.ir import DenseI64ArrayAttr, OpView, RankedTensorType, Value from pydsl.func import InlineFunction -from pydsl.macro import CallMacro, Compiled, Evaluated +from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType from pydsl.memref import ( + assert_shapes_compatible, UsesRMRD, RuntimeMemrefShape, slices_to_mlir_format, @@ -275,6 +276,44 @@ def on_class_getitem( # Equivalent to cls.__class_getitem__(args) return cls[args] + @CallMacro.generate(method_type=MethodType.INSTANCE) + def cast( + visitor: ToMLIRBase, self: typing.Self, shape: Evaluated + ) -> typing.Self: + """ + Convert a tensor from one type to an equivalent type without changing + any data elements. The resulting tensor type will have the same element + type. shape is the shape of the new tensor and must be known at compile + time. For any constant dimensions of shape, the input tensor must + actually have that dimension at runtime, otherwise the operation is + invalid. + + Note: this function only returns a tensor with the updated type, it + does not modify the type of the input tensor. + + Example: + ``` + def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]: + # Only valid if the first dimension of t1 is always 64 + t2 = t1.cast((64, 32, DYNAMIC)) + return t2 + ``` + """ + + shape = tuple(shape) + + if not all(isinstance(x, int) for x in shape): + raise TypeError( + f"shape should be a tuple of integers known at compile time ", + f"got {repr(shape)}", + ) + + assert_shapes_compatible(self.shape, shape) + + result_type = self.class_factory(shape, self.element_type) + rep = tensor.cast(lower_single(result_type), lower_single(self)) + return result_type(rep) + # Convenient alias TensorFactory = Tensor.class_factory diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index a48fc00..92dd050 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -434,6 +434,60 @@ def f(t1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: assert res2.shape == () +def test_cast_basic(): + @compile() + def f( + m1: MemRef[F32, 10, DYNAMIC, 30], + ) -> MemRef[F32, DYNAMIC, 20, DYNAMIC]: + m1 = m1.cast((DYNAMIC, 20, 30), strides=(600, 30, 1)) + m1 = m1.cast((DYNAMIC, DYNAMIC, DYNAMIC)) + m1 = m1.cast((10, 20, 30), strides=None) + m1 = m1.cast((DYNAMIC, 20, DYNAMIC)) + return m1 + + n1 = multi_arange((10, 20, 30), dtype=np.float32) + cor_res = n1.copy() + assert (f(n1) == cor_res).all() + + +def test_cast_strided(): + MemRef1 = MemRef.get((8, DYNAMIC), SInt16, offset=10, strides=(1, 8)) + MemRef2 = MemRef.get((8, 4), SInt16, offset=DYNAMIC, strides=(DYNAMIC, 8)) + + @compile() + def f(m1: MemRef1) -> MemRef2: + m1.cast(strides=(1, DYNAMIC)) + m1 = m1.cast((8, 4), offset=DYNAMIC, strides=(DYNAMIC, 8)) + return m1 + + i16_sz = np.int16().nbytes + n1 = multi_arange((8, 4), np.int16) + n1 = as_strided(n1, shape=(8, 4), strides=(i16_sz, 8 * i16_sz)) + cor_res = n1.copy() + assert (f(n1) == cor_res).all() + + +def test_cast_bad(): + with compilation_failed_from(ValueError): + + @compile() + 
def f1(m1: MemRef[UInt32, 5, 8]): + m1.cast((5, 8, 7)) + + with compilation_failed_from(ValueError): + + @compile() + def f2(m1: MemRef[F64, DYNAMIC, 4]): + m1.cast((DYNAMIC, 5)) + + with compilation_failed_from(ValueError): + MemRef1 = MemRef.get((8, DYNAMIC), F32, offset=10, strides=(1, 8)) + + @compile() + def f3(m1: MemRef1): + m1.cast(strides=(DYNAMIC, 16)) + + if __name__ == "__main__": run(test_load_implicit_index_uint32) run(test_load_implicit_index_f64) @@ -457,3 +511,6 @@ def f(t1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: run(test_link_ndarray) run(test_chain_link_ndarray) run(test_zero_d) + run(test_cast_basic) + run(test_cast_strided) + run(test_cast_strided) diff --git a/tests/e2e/test_tensor.py b/tests/e2e/test_tensor.py index b9d63bc..191a668 100644 --- a/tests/e2e/test_tensor.py +++ b/tests/e2e/test_tensor.py @@ -348,6 +348,31 @@ def f() -> Tensor[F64, DYNAMIC, DYNAMIC]: assert (test_res == cor_res).all() +def test_cast(): + @compile() + def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]: + t2 = t1.cast((64, 32, DYNAMIC)) + return t2 + + n1 = multi_arange((64, 32, 5), np.float32) + cor_res = n1.copy() + assert (f(n1) == cor_res).all() + + +def test_cast_bad(): + with compilation_failed_from(ValueError): + + @compile() + def f1(t1: Tensor[SInt32, DYNAMIC, 13]): + t1.cast((7, 13, DYNAMIC)) + + with compilation_failed_from(ValueError): + + @compile() + def f2(t1: Tensor[UInt64, 3, 7]): + t1.cast((1, 21)) + + if __name__ == "__main__": run(test_wrong_dim) run(test_load) @@ -372,3 +397,5 @@ def f() -> Tensor[F64, DYNAMIC, DYNAMIC]: run(test_full) run(test_zeros) run(test_ones) + run(test_cast) + run(test_cast_bad) From edf1630a318ac51ad3ddb844690f03cb826b9a41 Mon Sep 17 00:00:00 2001 From: jamesthejellyfish Date: Tue, 20 Jan 2026 11:13:49 -0500 Subject: [PATCH 3/3] Refactor test_alloc_bad_align by removing assertions Remove assertions from test_alloc_bad_align function. --- tests/e2e/test_memref.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index 0e8686e..8b0b94f 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -182,20 +182,15 @@ def f() -> MemRef[F64, 4, 6]: def test_alloc_bad_align(): with compilation_failed_from(TypeError): - @compile() def f(): alloc((4, 6), F64, alignment="xyz") with compilation_failed_from(ValueError): - @compile() def f(): alloc((4, 6), F64, alignment=-123) - mlir = f.emit_mlir() - assert r"memref.dealloc" in mlir - def test_slice_memory_space(): """ We can't use GPU address spaces on CPU, so just test if it compiles.
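
The three patches above rework how buffers are allocated, placed in memory spaces, deallocated, and cast. The sketch below pulls those pieces together into one place; it is a minimal usage example that assumes only the signatures visible in the diffs (`alloc`/`alloca` taking a shape tuple and an element type, the optional `memory_space=` keyword, `dealloc`, and the new `MemRef.cast` instance macro) and mirrors the added test cases rather than introducing new behaviour. The function names are illustrative, not part of the patch set.

```python
# Usage sketch of the reworked pydsl allocation/cast API from these patches.
# Only constructs that appear in the diffs above are used; function names
# (scratch_sum, gpu_scratch, refine) are illustrative.
from pydsl.frontend import compile
from pydsl.gpu import GPU_AddrSpace
from pydsl.memref import DYNAMIC, MemRef, alloc, dealloc
from pydsl.type import F32, Index, SInt16


# alloc/alloca now take a shape tuple (mixing Python ints and Index values)
# plus an element type, instead of a pre-built MemRef subclass and a separate
# dynamic_sizes tuple (cf. test_dealloc).
@compile()
def scratch_sum() -> SInt16:
    m = alloc((3, Index(7)), SInt16)  # heap allocation, default memory space
    m[1, 2] = 5
    res = m[1, 2]
    dealloc(m)                        # explicit deallocation is now exposed
    return res


# memory_space= attaches an MLIR memory-space attribute to the allocated
# memref type. GPU address spaces cannot run on the host, so this variant is
# only lowered to MLIR, as in test_alloc_memory_space.
@compile(auto_build=False)
def gpu_scratch():
    m = alloc((4, 2), F32, memory_space=GPU_AddrSpace.Global)
    dealloc(m)


# MemRef.cast converts between compatible memref types, e.g. refining or
# erasing static dimensions; refined dynamic extents are checked at runtime
# (cf. the cast docstring example and test_cast_basic).
@compile()
def refine(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]:
    return m1.cast((64, 32, DYNAMIC))
```

As in `test_alloc_memory_space`, the GPU-address-space variant is compiled with `auto_build=False` and only its emitted MLIR is meaningful, since the GPU memory space cannot be executed on the CPU.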