diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py
index 157fd0e..f264a51 100644
--- a/src/pydsl/memref.py
+++ b/src/pydsl/memref.py
@@ -21,7 +21,7 @@
 )

 from pydsl.affine import AffineContext, AffineMapExpr, AffineMapExprWalk
-from pydsl.macro import CallMacro, Compiled, Evaluated
+from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType
 from pydsl.protocols import (
     ArgContainer,
     canonicalize_args,
@@ -181,6 +181,18 @@ def are_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> bool:
     )


+def assert_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> None:
+    """
+    Checks that the non-dynamic dimensions of arr1 and arr2 match and raises
+    a ValueError if they do not.
+    """
+    if not are_shapes_compatible(arr1, arr2):
+        raise ValueError(
+            f"incompatible shapes: {repr(arr1)} and {repr(arr2)}, non-dynamic "
+            f"dimensions must be equal"
+        )
+
+
 class UsesRMRD:
     """
     A mixin class for adding CType support for classes that eventually lower
@@ -679,6 +691,86 @@ def on_class_getitem(
         # Equivalent to cls.__class_getitem__(args)
         return cls[args]

+    @CallMacro.generate(method_type=MethodType.INSTANCE)
+    def cast(
+        visitor: ToMLIRBase,
+        self: typing.Self,
+        shape: Evaluated = None,
+        *,
+        offset: Evaluated = None,
+        strides: Evaluated = (-1,),
+    ) -> typing.Self:
+        """
+        Converts a memref from one type to an equivalent type with a compatible
+        shape. The source and destination types are compatible if all of the
+        following are true:
+        - Both are ranked memref types with the same element type, address
+          space, and rank.
+        - Both have the same layout or both have compatible strided layouts.
+        - The individual sizes (resp. offset and strides in the case of strided
+          memrefs) may convert constant dimensions to dynamic dimensions and
+          vice-versa.
+
+        If the cast converts any dimensions from an unknown to a known size,
+        then it acts as an assertion that fails at runtime if the dynamic
+        dimensions disagree with the resultant destination size (i.e. it is
+        illegal to do a conversion that causes a mismatch, and doing so would
+        invoke undefined behaviour).
+
+        Note: a new memref with the new type is returned; the type of the
+        original memref is not modified.
+
+        Example:
+        ```
+        def f(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]:
+            # Only valid if the first dimension of m1 is always 64
+            m2 = m1.cast((64, 32, DYNAMIC))
+            return m2
+        ```
+        """
+        # NOTE: default value of strides is (-1,) because None is already used
+        # to represent default layout... This is definitely not a great
+        # solution, but it's unclear how to do this better.
+
+        shape = tuple(shape) if shape is not None else self.shape
+        offset = int(offset) if offset is not None else self.offset
+        strides = (
+            None
+            if strides is None
+            else self.strides
+            if strides == (-1,)
+            else tuple(strides)
+        )
+
+        if not all(isinstance(x, int) for x in shape):
+            raise ValueError(
+                "shape should be a tuple of integers known at compile time, "
+                f"got {repr(shape)}"
+            )
+
+        if not (strides is None or all(isinstance(x, int) for x in strides)):
+            raise ValueError(
+                "strides should be a tuple of integers known at compile "
+                f"time, got {repr(strides)}"
+            )
+
+        assert_shapes_compatible(self.shape, shape)
+        assert_shapes_compatible([self.offset], [offset])
+
+        # TODO: also do a type check if one of the MemRefs is default layout.
+        # This is a reasonably complicated check, and even MLIR doesn't do it
+        # properly. E.g. mlir-opt allows
+        # `memref<3x4xf64, strided<[8, 1], offset: ?>> to memref`
+        # to compile, even though it is impossible for this to be correct.
+        if self.strides is not None and strides is not None:
+            assert_shapes_compatible(self.strides, strides)
+
+        result_type = self.class_factory(
+            shape, self.element_type, offset=offset, strides=strides
+        )
+        rep = memref.cast(lower_single(result_type), lower_single(self))
+        return result_type(rep)
+
     def lower(self) -> tuple[Value]:
         return (self.value,)

@@ -739,8 +831,8 @@ def _alloc_generic(
     alloc_func: Callable,
     shape: Compiled,
     dtype: Evaluated,
-    memory_space: MemorySpace | None,
-    alignment: int | None,
+    memory_space: MemorySpace | None = None,
+    alignment: int | None = None,
 ) -> SubtreeOut:
     """
     Does the logic required for alloc/alloca. It was silly having two functions
diff --git a/src/pydsl/tensor.py b/src/pydsl/tensor.py
index 3495c93..4236a2b 100644
--- a/src/pydsl/tensor.py
+++ b/src/pydsl/tensor.py
@@ -8,8 +8,9 @@
 from mlir.ir import DenseI64ArrayAttr, OpView, RankedTensorType, Value

 from pydsl.func import InlineFunction
-from pydsl.macro import CallMacro, Compiled, Evaluated
+from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType
 from pydsl.memref import (
+    assert_shapes_compatible,
     UsesRMRD,
     RuntimeMemrefShape,
     slices_to_mlir_format,
@@ -275,6 +276,44 @@ def on_class_getitem(
         # Equivalent to cls.__class_getitem__(args)
         return cls[args]

+    @CallMacro.generate(method_type=MethodType.INSTANCE)
+    def cast(
+        visitor: ToMLIRBase, self: typing.Self, shape: Evaluated
+    ) -> typing.Self:
+        """
+        Convert a tensor from one type to an equivalent type without changing
+        any data elements. The resulting tensor type will have the same element
+        type. shape is the shape of the new tensor and must be known at compile
+        time. For any constant dimensions of shape, the input tensor must
+        actually have that dimension at runtime; otherwise the operation is
+        invalid.
+
+        Note: this function only returns a tensor with the updated type; it
+        does not modify the type of the input tensor.
+
+        Example:
+        ```
+        def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]:
+            # Only valid if the first dimension of t1 is always 64
+            t2 = t1.cast((64, 32, DYNAMIC))
+            return t2
+        ```
+        """
+
+        shape = tuple(shape)
+
+        if not all(isinstance(x, int) for x in shape):
+            raise TypeError(
+                "shape should be a tuple of integers known at compile time, "
+                f"got {repr(shape)}"
+            )
+
+        assert_shapes_compatible(self.shape, shape)
+
+        result_type = self.class_factory(shape, self.element_type)
+        rep = tensor.cast(lower_single(result_type), lower_single(self))
+        return result_type(rep)
+

 # Convenient alias
 TensorFactory = Tensor.class_factory
diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py
index e79bb70..8b0b94f 100644
--- a/tests/e2e/test_memref.py
+++ b/tests/e2e/test_memref.py
@@ -182,18 +182,15 @@ def f() -> MemRef[F64, 4, 6]:

 def test_alloc_bad_align():
     with compilation_failed_from(TypeError):
-
         @compile()
         def f():
             alloc((4, 6), F64, alignment="xyz")

     with compilation_failed_from(ValueError):
-
         @compile()
         def f():
             alloc((4, 6), F64, alignment=-123)

-
 def test_slice_memory_space():
     """
     We can't use GPU address spaces on CPU, so just test if it compiles.
@@ -483,6 +480,59 @@ def f(m1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]:
     assert res2.shape == ()


+def test_cast_basic():
+    @compile()
+    def f(
+        m1: MemRef[F32, 10, DYNAMIC, 30],
+    ) -> MemRef[F32, DYNAMIC, 20, DYNAMIC]:
+        m1 = m1.cast((DYNAMIC, 20, 30), strides=(600, 30, 1))
+        m1 = m1.cast((DYNAMIC, DYNAMIC, DYNAMIC))
+        m1 = m1.cast((10, 20, 30), strides=None)
+        m1 = m1.cast((DYNAMIC, 20, DYNAMIC))
+        return m1
+
+    n1 = multi_arange((10, 20, 30), dtype=np.float32)
+    cor_res = n1.copy()
+    assert (f(n1) == cor_res).all()
+
+
+def test_cast_strided():
+    MemRef1 = MemRef.get((8, DYNAMIC), SInt16, offset=10, strides=(1, 8))
+    MemRef2 = MemRef.get((8, 4), SInt16, offset=DYNAMIC, strides=(DYNAMIC, 8))
+
+    @compile()
+    def f(m1: MemRef1) -> MemRef2:
+        m1.cast(strides=(1, DYNAMIC))
+        m1 = m1.cast((8, 4), offset=DYNAMIC, strides=(DYNAMIC, 8))
+        return m1
+
+    i16_sz = np.int16().nbytes
+    n1 = multi_arange((8, 4), np.int16)
+    n1 = as_strided(n1, shape=(8, 4), strides=(i16_sz, 8 * i16_sz))
+    cor_res = n1.copy()
+    assert (f(n1) == cor_res).all()
+
+
+def test_cast_bad():
+    with compilation_failed_from(ValueError):
+
+        @compile()
+        def f1(m1: MemRef[UInt32, 5, 8]):
+            m1.cast((5, 8, 7))
+
+    with compilation_failed_from(ValueError):
+
+        @compile()
+        def f2(m1: MemRef[F64, DYNAMIC, 4]):
+            m1.cast((DYNAMIC, 5))
+
+    with compilation_failed_from(ValueError):
+        MemRef1 = MemRef.get((8, DYNAMIC), F32, offset=10, strides=(1, 8))
+
+        @compile()
+        def f3(m1: MemRef1):
+            m1.cast(strides=(DYNAMIC, 16))
+
 def test_copy_basic():
     @compile()
     def f(m1: MemRef[SInt16, 10, DYNAMIC], m2: MemRef[SInt16, 10, 10]):
@@ -562,6 +612,9 @@ def f(m1: MemRef[F32, 5], m2: MemRef[F64, 5]):
     run(test_link_ndarray)
     run(test_chain_link_ndarray)
     run(test_zero_d)
+    run(test_cast_basic)
+    run(test_cast_strided)
+    run(test_cast_bad)
     run(test_copy_basic)
     run(test_copy_overlap)
     run(test_copy_strided)
diff --git a/tests/e2e/test_tensor.py b/tests/e2e/test_tensor.py
index b9d63bc..191a668 100644
--- a/tests/e2e/test_tensor.py
+++ b/tests/e2e/test_tensor.py
@@ -348,6 +348,31 @@ def f() -> Tensor[F64, DYNAMIC, DYNAMIC]:
     assert (test_res == cor_res).all()


+def test_cast():
+    @compile()
+    def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]:
+        t2 = t1.cast((64, 32, DYNAMIC))
+        return t2
+
+    n1 = multi_arange((64, 32, 5), np.float32)
+    cor_res = n1.copy()
+    assert (f(n1) == cor_res).all()
+
+
+def test_cast_bad():
+    with compilation_failed_from(ValueError):
+
+        @compile()
+        def f1(t1: Tensor[SInt32, DYNAMIC, 13]):
+            t1.cast((7, 13, DYNAMIC))
+
+    with compilation_failed_from(ValueError):
+
+        @compile()
+        def f2(t1: Tensor[UInt64, 3, 7]):
+            t1.cast((1, 21))
+
+
 if __name__ == "__main__":
     run(test_wrong_dim)
     run(test_load)
@@ -372,3 +397,5 @@ def f() -> Tensor[F64, DYNAMIC, DYNAMIC]:
     run(test_full)
     run(test_zeros)
     run(test_ones)
+    run(test_cast)
+    run(test_cast_bad)
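Not part of the patch: below is a minimal usage sketch of the new `cast` macro, written in the style of `test_cast_basic`/`test_cast` above. The `.cast()` call itself comes from this diff; the import paths for `compile`, `MemRef`, `F32`, and `DYNAMIC` are assumptions (the e2e tests already have these names in scope) and may need to be adjusted to the actual module layout.

```python
# Usage sketch only -- NOT part of the patch above.
# Assumption: `compile`, `MemRef`, `F32` and `DYNAMIC` are importable the same
# way the existing e2e tests import them; the paths below are placeholders.
import numpy as np

from pydsl.frontend import compile  # assumed path
from pydsl.memref import MemRef     # assumed path
from pydsl.type import F32, DYNAMIC  # assumed path


@compile()
def refine(m: MemRef[F32, DYNAMIC, 8]) -> MemRef[F32, 64, 8]:
    # cast() moves no data: it only refines the static type. Per the docstring
    # in the diff, it is undefined behaviour if the runtime leading dimension
    # is not actually 64.
    return m.cast((64, 8))


x = np.arange(64 * 8, dtype=np.float32).reshape(64, 8)
assert (refine(x) == x).all()
```

The call lowers to a single `memref.cast` op (and `tensor.cast` for the Tensor variant), so it is purely a type-level conversion; the compile-time checks added here only reject shapes, offsets, or strides whose non-dynamic dimensions disagree.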