-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[WIP] Add mlx.core.searchsorted #2817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Introduces a comprehensive Product Requirements Document for the new mlx.searchsorted feature, outlining API parity with NumPy, multi-dimensional support, axis-based operations, and GPU-accelerated implementation plans. Documents goals, functional/spec requirements, testing strategies, documentation plans, release notes, risks, and implementation phases. References the related issue ml-explore#1255.
Introduces a comprehensive test suite for the searchsorted operation, validating correctness across multiple dimensions, dtypes (int, float, half precision), and sides, including: - Basic 1D behavior and dtype compatibility - Handling of duplicates and edge cases (empty, single element, all duplicates) - Multidimensional inputs with various axes and broadcasting rules - Special values (inf, -inf, NaN) and out-of-range scenarios - Output shape and dtype correctness for scalar, vector, and 2D value inputs - Mixed precision scenarios to ensure robust type handling - Large input arrays for performance/behavior checks
Introduces a .clangd configuration to enhance editor tooling and code quality. - Enables strict diagnostic checks (UnusedIncludes, MissingIncludes) and enables Clang-Tidy with tailored rule sets. - Configures compilation database and background indexing to speed up tooling feedback. - Enables inlay hints (parameter names, deduced types) and hover improvements for better readability. - Activates comprehensive completion across all scopes and improves developer experience with AKA hover support.
Drops unused header dependencies: - removes transforms.h from the operations module - removes backend/utils.h from primitives
Introduces a local numpy-based reference implementation for searchsorted that supports axis and broadcasting, intended to validate the multi-dimensional behavior against the library. Updates test expectations for scalar input on 2D arrays. Temporarily comments out extensive multi-dimensional test cases pending full parity, but leaves the groundwork for future validation and comparison.
Improves code hygiene by: - Wrapping a long `set` command in `CMakeLists.txt` for better readability. - Removing trailing whitespace in `test_ops.py`. - Standardizing string literal quotes (single to double) in `test_ops.py`. - Converting whitespace-only lines to truly empty lines.
This reverts commit abf4f49.
401acee to
94c7e08
Compare
Introduces a fast-path for fully contiguous arrays to avoid unnecessary copying and allocation. Switches between linear, binary, and exponential search based on axis size, improving performance for both small and large inputs. Reduces overhead by using strided iterators directly for non-contiguous data.
|
Ran benchmarks on personal branch to avoid commit clutter profiling-searchsorted. Current implementation is faster (on cpu) than naive implementations from discussion in #1255 , but slower than pytorch searchsorted (by about 1.5x times). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds the searchsorted function to MLX, implementing binary search functionality similar to NumPy's searchsorted. This is a work in progress addressing feature request #1255, with CPU implementation complete but GPU implementation and full axis parameter support still pending.
Key Changes:
- Implements C++ and Python APIs for
mlx.core.searchsortedwith binary search on CPU - Adds comprehensive test coverage for various edge cases, dtypes, and special values
- Includes optimizations for different array sizes (linear scan for small arrays, exponential search for large ones)
Reviewed changes
Copilot reviewed 12 out of 13 changed files in this pull request and generated 14 comments.
Show a summary per file
| File | Description |
|---|---|
| mlx/ops.h | Adds function declarations for searchsorted with and without axis parameter |
| mlx/ops.cpp | Implements main searchsorted logic including type promotion and axis handling |
| mlx/primitives.h | Defines SearchSorted primitive class with axis and side parameters |
| mlx/primitives.cpp | Implements SearchSorted primitive methods including vmap and output shape computation |
| mlx/backend/cpu/sort.cpp | Implements CPU evaluation with optimizations for contiguous arrays and different size ranges |
| python/src/ops.cpp | Adds Python bindings with documentation for the searchsorted function |
| tests/ops_tests.cpp | Adds C++ tests covering various dtypes, edge cases, and special values |
| python/tests/test_ops.py | Adds Python tests verifying NumPy compatibility for basic functionality and edge cases |
| docs/src/python/ops.rst | Adds searchsorted to the API documentation list |
| mlx/fast_primitives.h | Adds missing include for mlx/array.h |
| CMakeLists.txt | Adds virtualenv Python detection (unrelated to searchsorted) |
| .gitignore | Adds personal development tool entries (unrelated to searchsorted) |
| .clangd | Adds clangd configuration file (unrelated to searchsorted) |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| #include "mlx/fast_primitives.h" | ||
| #include "mlx/ops.h" | ||
| #include "mlx/primitives.h" |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removed include for "mlx/transforms.h" may be needed elsewhere in the file. Only remove this include if you've verified that no part of ops.cpp requires declarations from transforms.h. The fact that transforms_impl.h is still included suggests transforms.h might have been providing necessary declarations.
| #include "mlx/primitives.h" | |
| #include "mlx/primitives.h" | |
| #include "mlx/transforms.h" |
| if (to_ax != -1 && to_ax <= axis) { | ||
| axis++; | ||
| } | ||
| return {{searchsorted(a, v, right_, axis, stream())}, {to_ax}}; |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The vmap implementation has incorrect parameter order in the searchsorted call. The function is called as searchsorted(a, v, right_, axis, stream()) but the function signature in ops.h line 731-736 shows it should be searchsorted(a, v, axis, right, stream). The parameters right_ and axis are swapped, which will cause incorrect behavior or compilation errors.
| return {{searchsorted(a, v, right_, axis, stream())}, {to_ax}}; | |
| return {{searchsorted(a, v, axis, right_, stream())}, {to_ax}}; |
| CompileFlags: | ||
| CompilationDatabase: build | ||
|
|
||
| Index: | ||
| Background: Build |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The .clangd configuration file addition appears to be a personal development environment preference unrelated to the searchsorted feature. This should either be removed or submitted as a separate pull request focused on development tooling improvements.
| CompileFlags: | |
| CompilationDatabase: build | |
| Index: | |
| Background: Build |
| If there is no suitable index, return either 0 or N (where N is the | ||
| length of the array). |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation states "If there is no suitable index, return either 0 or N (where N is the length of the array)" but this is misleading. The function always returns a valid insertion index between 0 and N (inclusive). When no suitable index exists, it returns 0 for values less than all elements, or N for values greater than all elements. Consider rephrasing to: "Returns indices in the range [0, N] where N is the length of the sorted dimension."
| If there is no suitable index, return either 0 or N (where N is the | |
| length of the array). | |
| Returns indices in the range [0, N] where N is the length of the sorted | |
| dimension. |
| If there is no suitable index, return either 0 or N (where N is the | ||
| length of the array). | ||
| axis (int, optional): Axis along which to search. If not provided, the | ||
| array is flattened. |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The axis parameter documentation says "If not provided, the array is flattened" but looking at the implementation in ops.cpp, when axis is not provided, it calls reshape(a, {-1}, stream) which flattens the array and then searches along axis 0. This is correct but should be clarified in the documentation to match NumPy's behavior: "If not provided, the array is treated as if it were flattened and searched along a single axis."
| array is flattened. | |
| array is treated as if it were flattened and searched along a single axis (axis 0), matching NumPy's behavior. |
| if(NOT DEFINED Python_EXECUTABLE AND DEFINED ENV{VIRTUAL_ENV}) | ||
| if(WIN32) | ||
| set(_venv_py "$ENV{VIRTUAL_ENV}/Scripts/python.exe") | ||
| else() | ||
| set(_venv_py "$ENV{VIRTUAL_ENV}/bin/python") | ||
| endif() | ||
| if(EXISTS ${_venv_py}) | ||
| message(STATUS "Using Python from virtualenv: ${_venv_py}") | ||
| set(Python_EXECUTABLE | ||
| ${_venv_py} | ||
| CACHE FILEPATH "Python executable" FORCE) | ||
| endif() | ||
| endif() |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CMakeLists.txt changes adding virtualenv Python detection appear unrelated to the searchsorted feature. These changes should be submitted as a separate pull request to keep changes focused and easier to review. Including unrelated build system improvements in a feature PR makes it harder to track what changes are part of the feature vs. infrastructure improvements.
| if(NOT DEFINED Python_EXECUTABLE AND DEFINED ENV{VIRTUAL_ENV}) | |
| if(WIN32) | |
| set(_venv_py "$ENV{VIRTUAL_ENV}/Scripts/python.exe") | |
| else() | |
| set(_venv_py "$ENV{VIRTUAL_ENV}/bin/python") | |
| endif() | |
| if(EXISTS ${_venv_py}) | |
| message(STATUS "Using Python from virtualenv: ${_venv_py}") | |
| set(Python_EXECUTABLE | |
| ${_venv_py} | |
| CACHE FILEPATH "Python executable" FORCE) | |
| endif() | |
| endif() |
| result_mx = mx.searchsorted(a_mx, v_mx) | ||
| self.assertTrue(np.array_equal(result_np, result_mx)) | ||
|
|
||
| def test_searchsorted_int64_promotion(self): |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test name "test_searchsorted_int64_promotion" is misleading as it tests much more than just int64 type promotion. It includes tests for single element arrays, duplicate values, out of range values, special float values (inf, -inf, NaN), output shapes, dtypes, large arrays, 2D arrays, and mixed precision. Consider either renaming this to something more general like "test_searchsorted_edge_cases" or splitting it into multiple focused test methods.
| def test_searchsorted_int64_promotion(self): | |
| def test_searchsorted_edge_cases(self): |
| # Test output dtype is integer | ||
| a_mx = mx.array([1.0, 2.0, 3.0, 4.0, 5.0]) | ||
| result = mx.searchsorted(a_mx, mx.array([2.5, 4.5])) | ||
| self.assertTrue(result.dtype in [mx.int32, mx.int64, mx.uint32, mx.uint64]) |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.
| self.assertTrue(result.dtype in [mx.int32, mx.int64, mx.uint32, mx.uint64]) | |
| self.assertIn(result.dtype, [mx.int32, mx.int64, mx.uint32, mx.uint64]) |
| # shape = (3, 4, 5) | ||
| # for dtype in ("int32", "float32"): | ||
| # for axis in (None, 0, 1, 2, -1): | ||
| # with self.subTest(dtype=dtype, axis=axis): | ||
| # np.random.seed(0) | ||
| # np_dtype = getattr(np, dtype) | ||
| # a_np = np.sort( | ||
| # np.random.uniform(0, 100, size=shape).astype(np_dtype), | ||
| # axis=axis, | ||
| # ) | ||
| # a_mx = mx.array(a_np) | ||
|
|
||
| # # Create search values | ||
| # if axis is None: | ||
| # v_np = np.array([25, 50, 75], dtype=np_dtype) | ||
| # else: | ||
| # # Create values that broadcast correctly | ||
| # # v needs to be broadcastable against a_no_axis (2D) | ||
| # # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2) | ||
| # v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1) | ||
|
|
||
| # v_mx = mx.array(v_np) | ||
|
|
||
| # result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis) | ||
| # result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis) | ||
| # self.assertTrue(np.array_equal(result_np, result_mx)) | ||
|
|
||
| # Test edge cases | ||
| # Empty array |
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment appears to contain commented-out code.
| # shape = (3, 4, 5) | |
| # for dtype in ("int32", "float32"): | |
| # for axis in (None, 0, 1, 2, -1): | |
| # with self.subTest(dtype=dtype, axis=axis): | |
| # np.random.seed(0) | |
| # np_dtype = getattr(np, dtype) | |
| # a_np = np.sort( | |
| # np.random.uniform(0, 100, size=shape).astype(np_dtype), | |
| # axis=axis, | |
| # ) | |
| # a_mx = mx.array(a_np) | |
| # # Create search values | |
| # if axis is None: | |
| # v_np = np.array([25, 50, 75], dtype=np_dtype) | |
| # else: | |
| # # Create values that broadcast correctly | |
| # # v needs to be broadcastable against a_no_axis (2D) | |
| # # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2) | |
| # v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1) | |
| # v_mx = mx.array(v_np) | |
| # result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis) | |
| # result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis) | |
| # self.assertTrue(np.array_equal(result_np, result_mx)) | |
| # Test edge cases | |
| # Empty array | |
| # Test edge cases |
| # Test multi-dimensional with different axes | ||
| # shape = (3, 4, 5) | ||
| # for dtype in ("int32", "float32"): | ||
| # for axis in (None, 0, 1, 2, -1): | ||
| # with self.subTest(dtype=dtype, axis=axis): | ||
| # np.random.seed(0) | ||
| # np_dtype = getattr(np, dtype) | ||
| # a_np = np.sort( | ||
| # np.random.uniform(0, 100, size=shape).astype(np_dtype), | ||
| # axis=axis, | ||
| # ) | ||
| # a_mx = mx.array(a_np) | ||
|
|
||
| # # Create search values | ||
| # if axis is None: | ||
| # v_np = np.array([25, 50, 75], dtype=np_dtype) | ||
| # else: | ||
| # # Create values that broadcast correctly | ||
| # # v needs to be broadcastable against a_no_axis (2D) | ||
| # # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2) | ||
| # v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1) | ||
|
|
||
| # v_mx = mx.array(v_np) | ||
|
|
||
| # result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis) | ||
| # result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis) | ||
| # self.assertTrue(np.array_equal(result_np, result_mx)) | ||
|
|
Copilot
AI
Dec 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment appears to contain commented-out code.
| # Test multi-dimensional with different axes | |
| # shape = (3, 4, 5) | |
| # for dtype in ("int32", "float32"): | |
| # for axis in (None, 0, 1, 2, -1): | |
| # with self.subTest(dtype=dtype, axis=axis): | |
| # np.random.seed(0) | |
| # np_dtype = getattr(np, dtype) | |
| # a_np = np.sort( | |
| # np.random.uniform(0, 100, size=shape).astype(np_dtype), | |
| # axis=axis, | |
| # ) | |
| # a_mx = mx.array(a_np) | |
| # # Create search values | |
| # if axis is None: | |
| # v_np = np.array([25, 50, 75], dtype=np_dtype) | |
| # else: | |
| # # Create values that broadcast correctly | |
| # # v needs to be broadcastable against a_no_axis (2D) | |
| # # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2) | |
| # v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1) | |
| # v_mx = mx.array(v_np) | |
| # result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis) | |
| # result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis) | |
| # self.assertTrue(np.array_equal(result_np, result_mx)) |
First time contributor, appreciate any feedback! Work in progress on feature req from #1255.
Proposed changes
In Progress:
[ ] Add functionality for axis parameter for parity with np.searchsorted (documentation has it currently)
[x] Run benchmarks
[x] (not optimized) Implement vectorized linear search vs non-vectorized binary search
[ ] implement eval_gpu.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes