From f91df87d5b7a1a8cd9f343fa9314c1761791e655 Mon Sep 17 00:00:00 2001 From: Niklas Hansson Date: Fri, 19 Dec 2025 00:14:20 +0100 Subject: [PATCH 1/2] Remove redundant check to simplify the code Signed-off-by: Niklas Hansson --- src/cuda/tile/_ir/ops_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuda/tile/_ir/ops_utils.py b/src/cuda/tile/_ir/ops_utils.py index b25716c..2070993 100644 --- a/src/cuda/tile/_ir/ops_utils.py +++ b/src/cuda/tile/_ir/ops_utils.py @@ -154,7 +154,7 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze f'Rounding mode {rounding_mode.value} can only be used for float32 type, ' f'but got {dtype}') if flush_to_zero: - if flush_to_zero and not math_op_def.support_flush_to_zero: + if not math_op_def.support_flush_to_zero: raise TileTypeError(f'Flush to zero is not supported for {fn}') if dtype != datatype.float32: raise TileTypeError( From 6373e5331de72a367f4983196ac4049846ec5d86 Mon Sep 17 00:00:00 2001 From: Niklas Hansson Date: Fri, 19 Dec 2025 00:15:43 +0100 Subject: [PATCH 2/2] flatten and remove ifelse due to early returns Signed-off-by: Niklas Hansson --- src/cuda/tile/_ir/ops_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/cuda/tile/_ir/ops_utils.py b/src/cuda/tile/_ir/ops_utils.py index 2070993..c96cedd 100644 --- a/src/cuda/tile/_ir/ops_utils.py +++ b/src/cuda/tile/_ir/ops_utils.py @@ -188,14 +188,13 @@ def memory_order_has_release(memory_order: MemoryOrder): def get_dtype(ty: TileTy | datatype.DType | LooselyTypedScalar) -> datatype.DType | PointerTy: if isinstance(ty, TileTy): return ty.dtype - elif isinstance(ty, datatype.DType): + if isinstance(ty, datatype.DType): return ty - elif isinstance(ty, PointerTy): + if isinstance(ty, PointerTy): return ty - elif isinstance(ty, LooselyTypedScalar): + if isinstance(ty, LooselyTypedScalar): return typeof_pyval(ty.value) - else: - raise TypeError(f"Cannot get dtype from {ty}") + raise TypeError(f"Cannot get dtype from {ty}") def change_dtype(ty: TileTy | datatype.DType | PointerTy,