Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/api/ir_tensor_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,19 @@
.. autofunction:: onnx_ir.tensor_adapters.from_torch_dtype
.. autofunction:: onnx_ir.tensor_adapters.to_torch_dtype
```

## Adapters for MLX

```{eval-rst}
.. autosummary::
:toctree: generated
:template: classtemplate.rst
:nosignatures:

onnx_ir.tensor_adapters.MlxArray
```

```{eval-rst}
.. autofunction:: onnx_ir.tensor_adapters.from_mlx_dtype
.. autofunction:: onnx_ir.tensor_adapters.to_mlx_dtype
```
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"types-PyYAML",
"typing_extensions>=4.10",
"ml-dtypes",
"mlx",
"onnxruntime",
)
ONNX = "onnx==1.18"
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pyyaml
torch>=2.3
torchvision>=0.18.0
transformers>=4.37.2
mlx; sys_platform != "win32"

# Lint
lintrunner>=0.10.7
Expand Down
117 changes: 117 additions & 0 deletions src/onnx_ir/tensor_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
"from_torch_dtype",
"to_torch_dtype",
"TorchTensor",
"from_mlx_dtype",
"to_mlx_dtype",
"MlxArray",
]

import ctypes
Expand All @@ -43,11 +46,14 @@
from onnx_ir import _core

if TYPE_CHECKING:
import mlx.core as mx
import torch


# Lookup tables between framework dtypes and ONNX IR data types.
# Each table starts out as None. The two MLX tables are filled in on the first
# call to from_mlx_dtype() / to_mlx_dtype(), which is also where mlx.core is
# imported at runtime.
# NOTE(review): the torch tables presumably follow the same fill-on-first-use
# pattern; the code that populates them is not visible here - confirm.
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
_MLX_DTYPE_TO_ONNX: dict[mx.Dtype, ir.DataType] | None = None
_ONNX_DTYPE_TO_MLX: dict[ir.DataType, mx.Dtype] | None = None


def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
Expand Down Expand Up @@ -198,3 +204,114 @@ def tobytes(self) -> bytes:
def tofile(self, file) -> None:
_, data = self._get_cbytes()
return file.write(data)


def from_mlx_dtype(dtype: mx.Dtype) -> ir.DataType:
    """Map an MLX dtype onto the matching ONNX IR ``DataType``.

    The lookup table is built on the first call, so ``mlx.core`` is only
    imported once this function is actually used.

    Args:
        dtype: The MLX dtype to translate.

    Returns:
        The ``ir.DataType`` member registered for ``dtype``.

    Raises:
        TypeError: If ``dtype`` has no entry in the lookup table.
    """
    global _MLX_DTYPE_TO_ONNX
    if _MLX_DTYPE_TO_ONNX is None:
        import mlx.core as mx

        # Insertion order is kept because the error message below lists the keys.
        pairs = (
            (mx.bool_, ir.DataType.BOOL),
            (mx.uint8, ir.DataType.UINT8),
            (mx.uint16, ir.DataType.UINT16),
            (mx.uint32, ir.DataType.UINT32),
            (mx.uint64, ir.DataType.UINT64),
            (mx.int8, ir.DataType.INT8),
            (mx.int16, ir.DataType.INT16),
            (mx.int32, ir.DataType.INT32),
            (mx.int64, ir.DataType.INT64),
            (mx.float16, ir.DataType.FLOAT16),
            (mx.float32, ir.DataType.FLOAT),
            (mx.bfloat16, ir.DataType.BFLOAT16),
            (mx.complex64, ir.DataType.COMPLEX64),
        )
        _MLX_DTYPE_TO_ONNX = dict(pairs)

    table = _MLX_DTYPE_TO_ONNX
    if dtype in table:
        return table[dtype]

    supported = list(table.keys())
    raise TypeError(
        f"Unsupported MLX dtype '{dtype}'. "
        "Please use a supported dtype from the list: "
        f"{supported}"
    )


def to_mlx_dtype(dtype: ir.DataType) -> mx.Dtype:
    """Map an ONNX IR ``DataType`` onto the matching MLX dtype.

    The lookup table is built on the first call, so ``mlx.core`` is only
    imported once this function is actually used.

    Args:
        dtype: The ONNX IR data type to translate.

    Returns:
        The MLX dtype registered for ``dtype``.

    Raises:
        TypeError: If ``dtype`` has no entry in the lookup table.
    """
    global _ONNX_DTYPE_TO_MLX
    if _ONNX_DTYPE_TO_MLX is None:
        import mlx.core as mx

        # Insertion order is kept because the error message below lists the keys.
        pairs = (
            (ir.DataType.BOOL, mx.bool_),
            (ir.DataType.UINT8, mx.uint8),
            (ir.DataType.UINT16, mx.uint16),
            (ir.DataType.UINT32, mx.uint32),
            (ir.DataType.UINT64, mx.uint64),
            (ir.DataType.INT8, mx.int8),
            (ir.DataType.INT16, mx.int16),
            (ir.DataType.INT32, mx.int32),
            (ir.DataType.INT64, mx.int64),
            (ir.DataType.FLOAT16, mx.float16),
            (ir.DataType.FLOAT, mx.float32),
            (ir.DataType.BFLOAT16, mx.bfloat16),
            (ir.DataType.COMPLEX64, mx.complex64),
        )
        _ONNX_DTYPE_TO_MLX = dict(pairs)

    table = _ONNX_DTYPE_TO_MLX
    if dtype in table:
        return table[dtype]

    supported = list(table.keys())
    raise TypeError(
        f"Unsupported conversion from ONNX dtype '{dtype}' to mlx. "
        "Please use a supported dtype from the list: "
        f"{supported}"
    )


class MlxArray(_core.Tensor):
    """Tensor adapter for MLX arrays.

    This class wraps MLX arrays to make them compatible with the ONNX IR tensor
    protocol. The wrapped array is kept as the raw data and is converted to a
    NumPy array on demand (see :meth:`numpy`) for inspection and serialization.

    Example::

        import mlx.core as mx
        import onnx_ir as ir

        # Create an MLX array
        mlx_array = mx.array([1, 2, 3], dtype=mx.float32)

        # Wrap the MLX array in an MlxArray object
        ir_tensor = ir.tensor_adapters.MlxArray(mlx_array)

        # Use the IR tensor in the graph
        attr = ir.AttrTensor("x", ir_tensor)
        print(attr)
    """

    def __init__(
        self, array: mx.array, name: str | None = None, doc_string: str | None = None
    ):
        """Wrap ``array`` as an IR tensor.

        Args:
            array: The MLX array to wrap. Its dtype must be one that
                ``from_mlx_dtype`` supports.
            name: Optional tensor name, forwarded to the base class.
            doc_string: Optional documentation string, forwarded to the base class.

        Raises:
            TypeError: If ``array.dtype`` is rejected by ``from_mlx_dtype``.
        """
        # Pass the array as the raw data to ir.Tensor's constructor
        super().__init__(
            array, dtype=from_mlx_dtype(array.dtype), name=name, doc_string=doc_string
        )

    def numpy(self) -> npt.NDArray:
        """Return the wrapped array converted with ``np.asarray``."""
        import numpy as np

        # NOTE(review): the MLX documentation states that bfloat16 arrays need
        # to be cast before conversion to NumPy - confirm that bfloat16 inputs
        # convert correctly on the MLX versions this project supports.
        return np.asarray(self.raw)

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
        """Support ``np.array(tensor)`` / ``np.asarray(tensor)``.

        Args:
            dtype: Optional NumPy dtype to convert the result to.
            copy: Accepted for compatibility with the NumPy ``__array__``
                signature; it is ignored.
        """
        del copy  # Unused, but needed for the signature
        if dtype is None:
            return self.numpy()
        return self.numpy().__array__(dtype)

    def tobytes(self) -> bytes:
        """Return the array contents as bytes, via the NumPy conversion."""
        # Convert to numpy and then to bytes for serialization
        return self.numpy().tobytes()

    def tofile(self, file) -> None:
        """Write the array contents to ``file`` as raw bytes.

        Args:
            file: A writable binary file-like object. Only ``file.write`` is
                used, so in-memory buffers such as ``io.BytesIO`` work as well
                as real files.
        """
        # ``ndarray.tofile`` needs a real OS-level file and fails on in-memory
        # buffers; going through ``write`` works for any binary file-like
        # object and writes exactly the bytes that ``tobytes`` returns.
        file.write(self.tobytes())
184 changes: 182 additions & 2 deletions src/onnx_ir/tensor_adapters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import importlib.util
import io
import sys
import tempfile
import unittest

Expand All @@ -19,9 +20,23 @@


def skip_if_no(module_name: str):
    """Decorator to skip a test if a module is not installed or cannot be imported.

    Args:
        module_name: Dotted name of the module the decorated test depends on.

    Returns:
        A decorator. When the module is unavailable this is
        ``unittest.skip(...)`` with a reason describing why; otherwise it is an
        identity decorator that returns the test unchanged.
    """
    # Special handling for MLX: skip on Windows as it's not supported
    if module_name == "mlx.core" and sys.platform == "win32":
        return unittest.skip("mlx is not available on Windows")

    not_installed = unittest.skip(f"{module_name} not installed")
    try:
        # For a dotted name, find_spec imports the parent package first and
        # raises (instead of returning None) when that parent is missing.
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        return not_installed
    if spec is None:
        return not_installed

    # Being installed is not enough: the import itself can still fail (for
    # example on an unsupported platform), so try it for real.
    try:
        importlib.import_module(module_name)
    except Exception as e:  # Deliberately broad: any import failure means "skip".
        return unittest.skip(f"{module_name} cannot be imported: {e}")
    return lambda func: func


Expand Down Expand Up @@ -209,5 +224,170 @@ def test_tofile_non_contiguous(self):
self.assertEqual(tensor.tobytes(), expected_manual)


@skip_if_no("mlx.core")
class MlxArrayTest(unittest.TestCase):
    """Exercises the ``MlxArray`` adapter for each MLX dtype listed below."""

    # (qualified MLX dtype name, NumPy dtype expected back from the adapter)
    _DTYPE_CASES = [
        ("mlx.core.bool_", np.bool_),
        ("mlx.core.uint8", np.uint8),
        ("mlx.core.uint16", np.uint16),
        ("mlx.core.uint32", np.uint32),
        ("mlx.core.uint64", np.uint64),
        ("mlx.core.int8", np.int8),
        ("mlx.core.int16", np.int16),
        ("mlx.core.int32", np.int32),
        ("mlx.core.int64", np.int64),
        ("mlx.core.float16", np.float16),
        ("mlx.core.float32", np.float32),
        ("mlx.core.bfloat16", ml_dtypes.bfloat16),
        ("mlx.core.complex64", np.complex64),
    ]
    # The same dtype names as one-element parameter tuples.
    _DTYPE_NAMES = [(qualified_name,) for qualified_name, _ in _DTYPE_CASES]

    @staticmethod
    def _single_element_tensor(qualified_name: str):
        # Resolve a string such as "mlx.core.float32" to the mlx dtype object
        # and wrap a one-element array of that dtype in the adapter.
        import mlx.core as mx

        mlx_dtype = getattr(mx, qualified_name.rsplit(".", 1)[-1])
        return tensor_adapters.MlxArray(mx.array([1], dtype=mlx_dtype))

    @parameterized.parameterized.expand(_DTYPE_CASES)
    def test_numpy_returns_correct_dtype(self, mlx_dtype_str: str, np_dtype):
        # All three NumPy entry points must report the expected dtype.
        tensor = self._single_element_tensor(mlx_dtype_str)
        via_numpy = tensor.numpy()
        self.assertEqual(via_numpy.dtype, np_dtype)
        via_dunder = tensor.__array__()
        self.assertEqual(via_dunder.dtype, np_dtype)
        via_np_array = np.array(tensor)
        self.assertEqual(via_np_array.dtype, np_dtype)

    @parameterized.parameterized.expand(_DTYPE_NAMES)
    def test_tobytes(self, mlx_dtype_str: str):
        # tobytes() must agree with the bytes of the NumPy conversion.
        tensor = self._single_element_tensor(mlx_dtype_str)
        self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())

    def test_tofile_method_exists_and_works(self):
        """Test that tofile() method exists and works correctly."""
        import mlx.core as mx

        tensor = tensor_adapters.MlxArray(
            mx.array([1.0, 2.0, 3.0], dtype=mx.float32)
        )

        # An in-memory buffer stands in for a file here.
        sink = io.BytesIO()
        tensor.tofile(sink)
        written = sink.getvalue()

        self.assertEqual(written, tensor.tobytes())

    @parameterized.parameterized.expand(_DTYPE_NAMES)
    def test_tofile(self, mlx_dtype_str: str):
        """Test tofile() method for various data types."""
        tensor = self._single_element_tensor(mlx_dtype_str)

        with tempfile.NamedTemporaryFile() as handle:
            tensor.tofile(handle)
            # Rewind so the bytes just written can be read back.
            handle.seek(0)
            written = handle.read()

        self.assertEqual(written, tensor.tobytes())


@skip_if_no("mlx.core")
class MlxDtypeConversionTest(unittest.TestCase):
    """Checks ``to_mlx_dtype`` and ``from_mlx_dtype`` for every listed pair."""

    # (ONNX IR data type, qualified name of the matching MLX dtype)
    _ONNX_AND_MLX_NAME = [
        (ir.DataType.BOOL, "mlx.core.bool_"),
        (ir.DataType.UINT8, "mlx.core.uint8"),
        (ir.DataType.UINT16, "mlx.core.uint16"),
        (ir.DataType.UINT32, "mlx.core.uint32"),
        (ir.DataType.UINT64, "mlx.core.uint64"),
        (ir.DataType.INT8, "mlx.core.int8"),
        (ir.DataType.INT16, "mlx.core.int16"),
        (ir.DataType.INT32, "mlx.core.int32"),
        (ir.DataType.INT64, "mlx.core.int64"),
        (ir.DataType.FLOAT16, "mlx.core.float16"),
        (ir.DataType.FLOAT, "mlx.core.float32"),
        (ir.DataType.BFLOAT16, "mlx.core.bfloat16"),
        (ir.DataType.COMPLEX64, "mlx.core.complex64"),
    ]
    # The same pairs with the MLX name first, for the reverse direction.
    _MLX_NAME_AND_ONNX = [
        (qualified_name, onnx_dtype) for onnx_dtype, qualified_name in _ONNX_AND_MLX_NAME
    ]

    @staticmethod
    def _resolve(qualified_name: str):
        # Turn a string such as "mlx.core.float32" into the mlx dtype object.
        import mlx.core as mx

        return getattr(mx, qualified_name.rsplit(".", 1)[-1])

    @parameterized.parameterized.expand(_ONNX_AND_MLX_NAME)
    def test_to_mlx_dtype(self, onnx_dtype: ir.DataType, expected_mlx_dtype_str: str):
        # ONNX -> MLX must yield the dtype named in the table.
        expected = self._resolve(expected_mlx_dtype_str)
        converted = tensor_adapters.to_mlx_dtype(onnx_dtype)
        self.assertEqual(converted, expected)

    @parameterized.parameterized.expand(_MLX_NAME_AND_ONNX)
    def test_from_mlx_dtype(self, mlx_dtype_str: str, expected_onnx_dtype: ir.DataType):
        # MLX -> ONNX must yield the data type paired with the name in the table.
        source_dtype = self._resolve(mlx_dtype_str)
        converted = tensor_adapters.from_mlx_dtype(source_dtype)
        self.assertEqual(converted, expected_onnx_dtype)


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
unittest.main()
Loading