98 changes: 95 additions & 3 deletions src/pydsl/memref.py
@@ -21,7 +21,7 @@
)

from pydsl.affine import AffineContext, AffineMapExpr, AffineMapExprWalk
from pydsl.macro import CallMacro, Compiled, Evaluated
from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType
from pydsl.protocols import (
ArgContainer,
canonicalize_args,
@@ -181,6 +181,18 @@ def are_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> bool:
)


def assert_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> None:
"""
Checks that the non-dynamic dimensions of arr1 and arr2 match, raising a
ValueError if they do not.
"""
if not are_shapes_compatible(arr1, arr2):
raise ValueError(
f"incompatible shapes: {repr(arr1)} and {repr(arr2)}, non-dynamic "
f"dimensions must be equal"
)
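
# For example, (10, DYNAMIC, 30) and (DYNAMIC, 20, 30) are compatible because
# the only position where both sizes are static holds 30 on each side, whereas
# (DYNAMIC, 4) and (DYNAMIC, 5) mismatch in their static dimension and raise
# ValueError (see test_cast_bad in tests/e2e/test_memref.py).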


class UsesRMRD:
"""
A mixin class for adding CType support for classes that eventually lower
Expand Down Expand Up @@ -679,6 +691,86 @@ def on_class_getitem(
# Equivalent to cls.__class_getitem__(args)
return cls[args]

@CallMacro.generate(method_type=MethodType.INSTANCE)
def cast(
visitor: ToMLIRBase,
self: typing.Self,
shape: Evaluated = None,
*,
offset: Evaluated = None,
strides: Evaluated = (-1,),
) -> typing.Self:
"""
Converts a memref from one type to an equivalent type with a compatible
shape. The source and destination types are compatible if all of the
following are true:
- Both are ranked memref types with the same element type, address
space, and rank.
- Both have the same layout or both have compatible strided layouts.
- The individual sizes (resp. offset and strides in the case of strided
memrefs) may convert constant dimensions to dynamic dimensions and
vice-versa.

If the cast converts any dimensions from an unknown to a known size,
then it acts as an assertion that fails at runtime if the dynamic
dimensions disagree with the resultant destination size (i.e. it is
illegal to do a conversion that causes a mismatch, and it would invoke
undefined behaviour).

Note: a new memref with the new type is returned; the type of the
original memref is not modified.

Example:
```
def f(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]:
# Only valid if the first dimension of m1 is always 64
m2 = m1.cast((64, 32, DYNAMIC))
return m2
```
"""
# NOTE: default value of strides is (-1,) because None is already used
# to represent default layout... This is definitely not a great
# solution, but it's unclear how to do this better.

shape = tuple(shape) if shape is not None else self.shape
offset = int(offset) if offset is not None else self.offset
strides = (
None
if strides is None
else self.strides
if strides == (-1,)
else tuple(strides)
)

if not all(isinstance(x, int) for x in shape):
raise ValueError(
"shape should be a tuple of integers known at compile time, "
f"got {repr(shape)}"
)

if not (strides is None or all(isinstance(x, int) for x in strides)):
raise ValueError(
"strides should be a tuple of integers known at compile "
f"time, got {repr(strides)}"
)

assert_shapes_compatible(self.shape, shape)
assert_shapes_compatible([self.offset], [offset])

# TODO: also do a type check if one of the MemRefs is default layout.
# This is a reasonably complicated check, and even MLIR doesn't do it
# properly. E.g. mlir-opt allows
# `memref<3x4xf64, strided<[8, 1], offset: ?>> to memref<?x?xf64>`
# to compile, even though it is impossible for this to be correct.
if self.strides is not None and strides is not None:
assert_shapes_compatible(self.strides, strides)

result_type = self.class_factory(
shape, self.element_type, offset=offset, strides=strides
)
rep = memref.cast(lower_single(result_type), lower_single(self))
return result_type(rep)
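
# A rough sketch of how the strides argument is interpreted under the
# defaulting logic above (m1 is assumed to be a MemRef[F32, 10, DYNAMIC, 30],
# as in test_cast_basic in tests/e2e/test_memref.py):
#   m1.cast((DYNAMIC, 20, 30), strides=(600, 30, 1))  # explicit strided layout
#   m1.cast((10, 20, 30), strides=None)               # result uses the default layout
#   m1.cast((DYNAMIC, 20, DYNAMIC))                   # strides omitted: keep m1's layout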

def lower(self) -> tuple[Value]:
return (self.value,)

Expand Down Expand Up @@ -739,8 +831,8 @@ def _alloc_generic(
alloc_func: Callable,
shape: Compiled,
dtype: Evaluated,
memory_space: MemorySpace | None,
alignment: int | None,
memory_space: MemorySpace | None = None,
alignment: int | None = None,
) -> SubtreeOut:
"""
Does the logic required for alloc/alloca. It was silly having two functions
41 changes: 40 additions & 1 deletion src/pydsl/tensor.py
@@ -8,8 +8,9 @@
from mlir.ir import DenseI64ArrayAttr, OpView, RankedTensorType, Value

from pydsl.func import InlineFunction
from pydsl.macro import CallMacro, Compiled, Evaluated
from pydsl.macro import CallMacro, Compiled, Evaluated, MethodType
from pydsl.memref import (
assert_shapes_compatible,
UsesRMRD,
RuntimeMemrefShape,
slices_to_mlir_format,
@@ -275,6 +276,44 @@ def on_class_getitem(
# Equivalent to cls.__class_getitem__(args)
return cls[args]

@CallMacro.generate(method_type=MethodType.INSTANCE)
def cast(
visitor: ToMLIRBase, self: typing.Self, shape: Evaluated
) -> typing.Self:
"""
Convert a tensor from one type to an equivalent type without changing
any data elements. The resulting tensor type will have the same element
type. shape is the shape of the new tensor and must be known at compile
time. For any constant dimension of shape, the input tensor must
actually have that size at runtime; otherwise the operation is
invalid.

Note: this function only returns a tensor with the updated type; it
does not modify the type of the input tensor.

Example:
```
def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]:
# Only valid if the first dimension of t1 is always 64
t2 = t1.cast((64, 32, DYNAMIC))
return t2
```
"""

shape = tuple(shape)

if not all(isinstance(x, int) for x in shape):
raise TypeError(
"shape should be a tuple of integers known at compile time, "
f"got {repr(shape)}"
)

assert_shapes_compatible(self.shape, shape)

result_type = self.class_factory(shape, self.element_type)
rep = tensor.cast(lower_single(result_type), lower_single(self))
return result_type(rep)


# Convenient alias
TensorFactory = Tensor.class_factory
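
A minimal sketch of how the new Tensor.cast composes in a pydsl function, assuming the compile, Tensor, F64, and DYNAMIC symbols imported in the tests below; the cast only changes static/dynamic size annotations and never touches the data:

```
@compile()
def roundtrip(t: Tensor[F64, 8, DYNAMIC]) -> Tensor[F64, 8, DYNAMIC]:
    t = t.cast((DYNAMIC, DYNAMIC))  # forget the static size of the first dimension
    t = t.cast((8, DYNAMIC))        # reassert it; valid only if that size really is 8
    return t
```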
60 changes: 57 additions & 3 deletions tests/e2e/test_memref.py
@@ -182,18 +182,15 @@ def f() -> MemRef[F64, 4, 6]:

def test_alloc_bad_align():
with compilation_failed_from(TypeError):

@compile()
def f():
alloc((4, 6), F64, alignment="xyz")

with compilation_failed_from(ValueError):

@compile()
def f():
alloc((4, 6), F64, alignment=-123)


def test_slice_memory_space():
"""
We can't use GPU address spaces on CPU, so just test if it compiles.
@@ -483,6 +480,59 @@ def f(m1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]:
assert res2.shape == ()


def test_cast_basic():
@compile()
def f(
m1: MemRef[F32, 10, DYNAMIC, 30],
) -> MemRef[F32, DYNAMIC, 20, DYNAMIC]:
m1 = m1.cast((DYNAMIC, 20, 30), strides=(600, 30, 1))
m1 = m1.cast((DYNAMIC, DYNAMIC, DYNAMIC))
m1 = m1.cast((10, 20, 30), strides=None)
m1 = m1.cast((DYNAMIC, 20, DYNAMIC))
return m1

n1 = multi_arange((10, 20, 30), dtype=np.float32)
cor_res = n1.copy()
assert (f(n1) == cor_res).all()


def test_cast_strided():
MemRef1 = MemRef.get((8, DYNAMIC), SInt16, offset=10, strides=(1, 8))
MemRef2 = MemRef.get((8, 4), SInt16, offset=DYNAMIC, strides=(DYNAMIC, 8))

@compile()
def f(m1: MemRef1) -> MemRef2:
m1.cast(strides=(1, DYNAMIC))
m1 = m1.cast((8, 4), offset=DYNAMIC, strides=(DYNAMIC, 8))
return m1

i16_sz = np.int16().nbytes
n1 = multi_arange((8, 4), np.int16)
n1 = as_strided(n1, shape=(8, 4), strides=(i16_sz, 8 * i16_sz))
cor_res = n1.copy()
assert (f(n1) == cor_res).all()
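
# Note on test_cast_strided: NumPy strides are expressed in bytes while MemRef
# strides are in elements, so MemRef1's element strides (1, 8) correspond to
# the byte strides (i16_sz, 8 * i16_sz) passed to as_strided above.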


def test_cast_bad():
with compilation_failed_from(ValueError):

@compile()
def f1(m1: MemRef[UInt32, 5, 8]):
m1.cast((5, 8, 7))

with compilation_failed_from(ValueError):

@compile()
def f2(m1: MemRef[F64, DYNAMIC, 4]):
m1.cast((DYNAMIC, 5))

with compilation_failed_from(ValueError):
MemRef1 = MemRef.get((8, DYNAMIC), F32, offset=10, strides=(1, 8))

@compile()
def f3(m1: MemRef1):
m1.cast(strides=(DYNAMIC, 16))

def test_copy_basic():
@compile()
def f(m1: MemRef[SInt16, 10, DYNAMIC], m2: MemRef[SInt16, 10, 10]):
Expand Down Expand Up @@ -562,6 +612,10 @@ def f(m1: MemRef[F32, 5], m2: MemRef[F64, 5]):
run(test_link_ndarray)
run(test_chain_link_ndarray)
run(test_zero_d)
run(test_cast_basic)
run(test_cast_strided)
run(test_cast_bad)
run(test_copy_basic)
run(test_copy_overlap)
run(test_copy_strided)
27 changes: 27 additions & 0 deletions tests/e2e/test_tensor.py
@@ -348,6 +348,31 @@ def f() -> Tensor[F64, DYNAMIC, DYNAMIC]:
assert (test_res == cor_res).all()


def test_cast():
@compile()
def f(t1: Tensor[F32, DYNAMIC, 32, 5]) -> Tensor[F32, 64, 32, DYNAMIC]:
t2 = t1.cast((64, 32, DYNAMIC))
return t2

n1 = multi_arange((64, 32, 5), np.float32)
cor_res = n1.copy()
assert (f(n1) == cor_res).all()


def test_cast_bad():
with compilation_failed_from(ValueError):

@compile()
def f1(t1: Tensor[SInt32, DYNAMIC, 13]):
t1.cast((7, 13, DYNAMIC))

with compilation_failed_from(ValueError):

@compile()
def f2(t1: Tensor[UInt64, 3, 7]):
t1.cast((1, 21))


if __name__ == "__main__":
run(test_wrong_dim)
run(test_load)
@@ -372,3 +397,5 @@
run(test_full)
run(test_zeros)
run(test_ones)
run(test_cast)
run(test_cast_bad)