From a3e787a8c5179b5c1362edcd6e9b69788699c489 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 13 Dec 2025 18:26:25 +0000 Subject: [PATCH 1/7] Initial plan From f124077f72bd9b25aa62d1c560c89c83847fd85c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 13 Dec 2025 18:39:32 +0000 Subject: [PATCH 2/7] Add MLX tensor adapter with comprehensive tests Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/tensor_adapters.py | 120 +++++++++++++++++++ src/onnx_ir/tensor_adapters_test.py | 172 +++++++++++++++++++++++++++- 2 files changed, 291 insertions(+), 1 deletion(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 8be17108..f60b41da 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -32,6 +32,9 @@ "from_torch_dtype", "to_torch_dtype", "TorchTensor", + "from_mlx_dtype", + "to_mlx_dtype", + "MlxArray", ] import ctypes @@ -43,11 +46,14 @@ from onnx_ir import _core if TYPE_CHECKING: + import mlx.core as mx import torch _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None _ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None +_MLX_DTYPE_TO_ONNX: dict[mx.Dtype, ir.DataType] | None = None +_ONNX_DTYPE_TO_MLX: dict[ir.DataType, mx.Dtype] | None = None def from_torch_dtype(dtype: torch.dtype) -> ir.DataType: @@ -198,3 +204,117 @@ def tobytes(self) -> bytes: def tofile(self, file) -> None: _, data = self._get_cbytes() return file.write(data) + + +def from_mlx_dtype(dtype: mx.Dtype) -> ir.DataType: + """Convert an MLX dtype to an ONNX IR DataType.""" + global _MLX_DTYPE_TO_ONNX + if _MLX_DTYPE_TO_ONNX is None: + import mlx.core as mx + + _MLX_DTYPE_TO_ONNX = { + mx.bool_: ir.DataType.BOOL, + mx.uint8: ir.DataType.UINT8, + mx.uint16: ir.DataType.UINT16, + mx.uint32: ir.DataType.UINT32, + mx.uint64: ir.DataType.UINT64, + mx.int8: ir.DataType.INT8, + mx.int16: ir.DataType.INT16, + mx.int32: ir.DataType.INT32, + mx.int64: ir.DataType.INT64, + mx.float16: ir.DataType.FLOAT16, + mx.float32: ir.DataType.FLOAT, + mx.bfloat16: ir.DataType.BFLOAT16, + mx.complex64: ir.DataType.COMPLEX64, + } + + if dtype not in _MLX_DTYPE_TO_ONNX: + raise TypeError( + f"Unsupported MLX dtype '{dtype}'. " + "Please use a supported dtype from the list: " + f"{list(_MLX_DTYPE_TO_ONNX.keys())}" + ) + return _MLX_DTYPE_TO_ONNX[dtype] + + +def to_mlx_dtype(dtype: ir.DataType) -> mx.Dtype: + """Convert an ONNX IR DataType to an MLX dtype.""" + global _ONNX_DTYPE_TO_MLX + if _ONNX_DTYPE_TO_MLX is None: + import mlx.core as mx + + _ONNX_DTYPE_TO_MLX = { + ir.DataType.BOOL: mx.bool_, + ir.DataType.UINT8: mx.uint8, + ir.DataType.UINT16: mx.uint16, + ir.DataType.UINT32: mx.uint32, + ir.DataType.UINT64: mx.uint64, + ir.DataType.INT8: mx.int8, + ir.DataType.INT16: mx.int16, + ir.DataType.INT32: mx.int32, + ir.DataType.INT64: mx.int64, + ir.DataType.FLOAT16: mx.float16, + ir.DataType.FLOAT: mx.float32, + ir.DataType.BFLOAT16: mx.bfloat16, + ir.DataType.COMPLEX64: mx.complex64, + } + + if dtype not in _ONNX_DTYPE_TO_MLX: + raise TypeError( + f"Unsupported conversion from ONNX dtype '{dtype}' to mlx. " + "Please use a supported dtype from the list: " + f"{list(_ONNX_DTYPE_TO_MLX.keys())}" + ) + return _ONNX_DTYPE_TO_MLX[dtype] + + +class MlxArray(_core.Tensor): + """Tensor adapter for MLX arrays. + + This class wraps MLX arrays to make them compatible with the ONNX IR tensor protocol. + MLX arrays are efficiently converted to numpy arrays for serialization. + + Example:: + import mlx.core as mx + import onnx_ir as ir + + # Create an MLX array + mlx_array = mx.array([1, 2, 3], dtype=mx.float32) + + # Wrap the MLX array in an MlxArray object + ir_tensor = ir.tensor_adapters.MlxArray(mlx_array) + + # Use the IR tensor in the graph + attr = ir.AttrTensor("x", ir_tensor) + print(attr) + """ + + def __init__( + self, array: mx.array, name: str | None = None, doc_string: str | None = None + ): + # Pass the array as the raw data to ir.Tensor's constructor + super().__init__( + array, dtype=from_mlx_dtype(array.dtype), name=name, doc_string=doc_string + ) + + def numpy(self) -> npt.NDArray: + import numpy as np + + self.raw: mx.array + # MLX arrays support __dlpack__ for efficient zero-copy conversion + # We use numpy's from_dlpack to convert + return np.asarray(self.raw) + + def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: + del copy # Unused, but needed for the signature + if dtype is None: + return self.numpy() + return self.numpy().__array__(dtype) + + def tobytes(self) -> bytes: + # Convert to numpy and then to bytes for serialization + return self.numpy().tobytes() + + def tofile(self, file) -> None: + # Convert to numpy and write to file + return self.numpy().tofile(file) diff --git a/src/onnx_ir/tensor_adapters_test.py b/src/onnx_ir/tensor_adapters_test.py index 0af13627..447fa70f 100644 --- a/src/onnx_ir/tensor_adapters_test.py +++ b/src/onnx_ir/tensor_adapters_test.py @@ -19,9 +19,14 @@ def skip_if_no(module_name: str): - """Decorator to skip a test if a module is not installed.""" + """Decorator to skip a test if a module is not installed or cannot be imported.""" if importlib.util.find_spec(module_name) is None: return unittest.skip(f"{module_name} not installed") + # Try to actually import the module to check if it works + try: + __import__(module_name) + except Exception as e: + return unittest.skip(f"{module_name} cannot be imported: {e}") return lambda func: func @@ -209,5 +214,170 @@ def test_tofile_non_contiguous(self): self.assertEqual(tensor.tobytes(), expected_manual) +@skip_if_no("mlx.core") +class MlxArrayTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("mlx.core.bool_", np.bool_), + ("mlx.core.uint8", np.uint8), + ("mlx.core.uint16", np.uint16), + ("mlx.core.uint32", np.uint32), + ("mlx.core.uint64", np.uint64), + ("mlx.core.int8", np.int8), + ("mlx.core.int16", np.int16), + ("mlx.core.int32", np.int32), + ("mlx.core.int64", np.int64), + ("mlx.core.float16", np.float16), + ("mlx.core.float32", np.float32), + ("mlx.core.bfloat16", ml_dtypes.bfloat16), + ("mlx.core.complex64", np.complex64), + ], + ) + def test_numpy_returns_correct_dtype(self, mlx_dtype_str: str, np_dtype): + import mlx.core as mx + + # Get the actual mlx dtype from string like "mlx.core.float32" + dtype_name = mlx_dtype_str.split(".")[-1] + mlx_dtype = getattr(mx, dtype_name) + + mlx_array = mx.array([1], dtype=mlx_dtype) + tensor = tensor_adapters.MlxArray(mlx_array) + self.assertEqual(tensor.numpy().dtype, np_dtype) + self.assertEqual(tensor.__array__().dtype, np_dtype) + self.assertEqual(np.array(tensor).dtype, np_dtype) + + @parameterized.parameterized.expand( + [ + ("mlx.core.bool_",), + ("mlx.core.uint8",), + ("mlx.core.uint16",), + ("mlx.core.uint32",), + ("mlx.core.uint64",), + ("mlx.core.int8",), + ("mlx.core.int16",), + ("mlx.core.int32",), + ("mlx.core.int64",), + ("mlx.core.float16",), + ("mlx.core.float32",), + ("mlx.core.bfloat16",), + ("mlx.core.complex64",), + ], + ) + def test_tobytes(self, mlx_dtype_str: str): + import mlx.core as mx + + dtype_name = mlx_dtype_str.split(".")[-1] + mlx_dtype = getattr(mx, dtype_name) + + mlx_array = mx.array([1], dtype=mlx_dtype) + tensor = tensor_adapters.MlxArray(mlx_array) + self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) + + def test_tofile_method_exists_and_works(self): + """Test that tofile() method exists and works correctly.""" + import mlx.core as mx + + mlx_array = mx.array([1.0, 2.0, 3.0], dtype=mx.float32) + tensor = tensor_adapters.MlxArray(mlx_array) + + # Test with BytesIO buffer + buffer = io.BytesIO() + tensor.tofile(buffer) + result_bytes = buffer.getvalue() + + expected_bytes = tensor.tobytes() + self.assertEqual(result_bytes, expected_bytes) + + @parameterized.parameterized.expand( + [ + ("mlx.core.bool_",), + ("mlx.core.uint8",), + ("mlx.core.uint16",), + ("mlx.core.uint32",), + ("mlx.core.uint64",), + ("mlx.core.int8",), + ("mlx.core.int16",), + ("mlx.core.int32",), + ("mlx.core.int64",), + ("mlx.core.float16",), + ("mlx.core.float32",), + ("mlx.core.bfloat16",), + ("mlx.core.complex64",), + ], + ) + def test_tofile(self, mlx_dtype_str: str): + """Test tofile() method for various data types.""" + import mlx.core as mx + + dtype_name = mlx_dtype_str.split(".")[-1] + mlx_dtype = getattr(mx, dtype_name) + + mlx_array = mx.array([1], dtype=mlx_dtype) + tensor = tensor_adapters.MlxArray(mlx_array) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + expected_bytes = tensor.tobytes() + self.assertEqual(result_bytes, expected_bytes) + + +@skip_if_no("mlx.core") +class MlxDtypeConversionTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + (ir.DataType.BOOL, "mlx.core.bool_"), + (ir.DataType.UINT8, "mlx.core.uint8"), + (ir.DataType.UINT16, "mlx.core.uint16"), + (ir.DataType.UINT32, "mlx.core.uint32"), + (ir.DataType.UINT64, "mlx.core.uint64"), + (ir.DataType.INT8, "mlx.core.int8"), + (ir.DataType.INT16, "mlx.core.int16"), + (ir.DataType.INT32, "mlx.core.int32"), + (ir.DataType.INT64, "mlx.core.int64"), + (ir.DataType.FLOAT16, "mlx.core.float16"), + (ir.DataType.FLOAT, "mlx.core.float32"), + (ir.DataType.BFLOAT16, "mlx.core.bfloat16"), + (ir.DataType.COMPLEX64, "mlx.core.complex64"), + ] + ) + def test_to_mlx_dtype(self, onnx_dtype: ir.DataType, expected_mlx_dtype_str: str): + import mlx.core as mx + + dtype_name = expected_mlx_dtype_str.split(".")[-1] + expected_mlx_dtype = getattr(mx, dtype_name) + + actual = tensor_adapters.to_mlx_dtype(onnx_dtype) + self.assertEqual(actual, expected_mlx_dtype) + + @parameterized.parameterized.expand( + [ + ("mlx.core.bool_", ir.DataType.BOOL), + ("mlx.core.uint8", ir.DataType.UINT8), + ("mlx.core.uint16", ir.DataType.UINT16), + ("mlx.core.uint32", ir.DataType.UINT32), + ("mlx.core.uint64", ir.DataType.UINT64), + ("mlx.core.int8", ir.DataType.INT8), + ("mlx.core.int16", ir.DataType.INT16), + ("mlx.core.int32", ir.DataType.INT32), + ("mlx.core.int64", ir.DataType.INT64), + ("mlx.core.float16", ir.DataType.FLOAT16), + ("mlx.core.float32", ir.DataType.FLOAT), + ("mlx.core.bfloat16", ir.DataType.BFLOAT16), + ("mlx.core.complex64", ir.DataType.COMPLEX64), + ] + ) + def test_from_mlx_dtype(self, mlx_dtype_str: str, expected_onnx_dtype: ir.DataType): + import mlx.core as mx + + dtype_name = mlx_dtype_str.split(".")[-1] + mlx_dtype = getattr(mx, dtype_name) + + actual = tensor_adapters.from_mlx_dtype(mlx_dtype) + self.assertEqual(actual, expected_onnx_dtype) + + if __name__ == "__main__": unittest.main() From a9234c7c4ac6a650eab4c3e67e7394ca40deeadd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 13 Dec 2025 18:42:11 +0000 Subject: [PATCH 3/7] Address code review: remove redundant type annotation Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/tensor_adapters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index f60b41da..2783505d 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -300,7 +300,6 @@ def __init__( def numpy(self) -> npt.NDArray: import numpy as np - self.raw: mx.array # MLX arrays support __dlpack__ for efficient zero-copy conversion # We use numpy's from_dlpack to convert return np.asarray(self.raw) From a5b1a3a4b0311c54c5cfec794ffbd1f2bc38ff88 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 13 Dec 2025 18:44:35 +0000 Subject: [PATCH 4/7] Add MLX adapter to API documentation Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- docs/api/ir_tensor_adapters.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/api/ir_tensor_adapters.md b/docs/api/ir_tensor_adapters.md index d43b9d6e..02595087 100644 --- a/docs/api/ir_tensor_adapters.md +++ b/docs/api/ir_tensor_adapters.md @@ -19,3 +19,19 @@ .. autofunction:: onnx_ir.tensor_adapters.from_torch_dtype .. autofunction:: onnx_ir.tensor_adapters.to_torch_dtype ``` + +## Adapters for MLX + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: classtemplate.rst + :nosignatures: + + onnx_ir.tensor_adapters.MlxArray +``` + +```{eval-rst} +.. autofunction:: onnx_ir.tensor_adapters.from_mlx_dtype +.. autofunction:: onnx_ir.tensor_adapters.to_mlx_dtype +``` From 85babd08ccfcf9c9321d11d64ca2e1407ae03a6f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 13 Dec 2025 10:52:16 -0800 Subject: [PATCH 5/7] Apply suggestion from @justinchuby Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 2783505d..502d4dd5 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -300,8 +300,6 @@ def __init__( def numpy(self) -> npt.NDArray: import numpy as np - # MLX arrays support __dlpack__ for efficient zero-copy conversion - # We use numpy's from_dlpack to convert return np.asarray(self.raw) def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: From 00053c3c788d3e1941d945b42f45a41950ea1455 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 13 Dec 2025 10:54:21 -0800 Subject: [PATCH 6/7] Add 'mlx' to the dependencies list Signed-off-by: Justin Chu --- noxfile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/noxfile.py b/noxfile.py index 3eff5c67..30a72a72 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,6 +28,7 @@ "types-PyYAML", "typing_extensions>=4.10", "ml-dtypes", + "mlx", "onnxruntime", ) ONNX = "onnx==1.18" From 94e1ef0555a6f4d21f3801c7a750815a2f0463de Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 13 Dec 2025 19:01:42 +0000 Subject: [PATCH 7/7] Add platform-specific MLX installation and Windows skip logic - Add mlx to requirements-dev.txt with sys_platform != "win32" condition - Update skip_if_no decorator to skip MLX tests on Windows - Improve error handling in skip_if_no for missing parent modules Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- requirements-dev.txt | 1 + src/onnx_ir/tensor_adapters_test.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index a6c1c625..5c1f9d86 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -28,6 +28,7 @@ pyyaml torch>=2.3 torchvision>=0.18.0 transformers>=4.37.2 +mlx; sys_platform != "win32" # Lint lintrunner>=0.10.7 diff --git a/src/onnx_ir/tensor_adapters_test.py b/src/onnx_ir/tensor_adapters_test.py index 447fa70f..5fc6f3b9 100644 --- a/src/onnx_ir/tensor_adapters_test.py +++ b/src/onnx_ir/tensor_adapters_test.py @@ -6,6 +6,7 @@ import importlib.util import io +import sys import tempfile import unittest @@ -20,8 +21,17 @@ def skip_if_no(module_name: str): """Decorator to skip a test if a module is not installed or cannot be imported.""" - if importlib.util.find_spec(module_name) is None: + # Special handling for MLX: skip on Windows as it's not supported + if module_name == "mlx.core" and sys.platform == "win32": + return unittest.skip("mlx is not available on Windows") + + try: + spec = importlib.util.find_spec(module_name) + if spec is None: + return unittest.skip(f"{module_name} not installed") + except (ModuleNotFoundError, ImportError, ValueError): return unittest.skip(f"{module_name} not installed") + # Try to actually import the module to check if it works try: __import__(module_name)