ppaleja commented Nov 21, 2025

First-time contributor, so I'd appreciate any feedback! This is a work in progress on the feature request from #1255.

Proposed changes

  • Python/C++ API for mlx.core.searchsorted (caveat: basic logic only, on CPU, using binary search).
  • Tests for searchsorted added to test_ops

In Progress:

[ ] Add functionality for axis parameter for parity with np.searchsorted (documentation has it currently)
[x] Run benchmarks
[x] (not optimized) Implement vectorized linear search vs non-vectorized binary search
[ ] Implement eval_gpu.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Introduces a comprehensive Product Requirements Document for the new mlx.searchsorted feature, outlining API parity with NumPy, multi-dimensional support, axis-based operations, and GPU-accelerated implementation plans. Documents goals, functional/spec requirements, testing strategies, documentation plans, release notes, risks, and implementation phases. References the related issue ml-explore#1255.

Introduces a comprehensive test suite for the searchsorted operation, validating correctness across multiple dimensions, dtypes (int, float, half precision), and sides, including:
- Basic 1D behavior and dtype compatibility
- Handling of duplicates and edge cases (empty, single element, all duplicates)
- Multidimensional inputs with various axes and broadcasting rules
- Special values (inf, -inf, NaN) and out-of-range scenarios
- Output shape and dtype correctness for scalar, vector, and 2D value inputs
- Mixed precision scenarios to ensure robust type handling
- Large input arrays for performance/behavior checks

Introduces a .clangd configuration to enhance editor tooling and code quality.

- Enables strict diagnostic checks (UnusedIncludes, MissingIncludes) and enables Clang-Tidy with tailored rule sets.
- Configures compilation database and background indexing to speed up tooling feedback.
- Enables inlay hints (parameter names, deduced types) and hover improvements for better readability.
- Activates comprehensive completion across all scopes and improves developer experience with AKA hover support.

Drops unused header dependencies:
- removes transforms.h from the operations module
- removes backend/utils.h from primitives

Introduces a local numpy-based reference implementation for searchsorted that supports axis and broadcasting, intended to validate the multi-dimensional behavior against the library. Updates test expectations for scalar input on 2D arrays. Temporarily comments out extensive multi-dimensional test cases pending full parity, but leaves the groundwork for future validation and comparison.

Improves code hygiene by:
- Wrapping a long `set` command in `CMakeLists.txt` for better readability.
- Removing trailing whitespace in `test_ops.py`.
- Standardizing string literal quotes (single to double) in `test_ops.py`.
- Converting whitespace-only lines to truly empty lines.

ppaleja force-pushed the feature/searchsorted branch from 401acee to 94c7e08 on December 11, 2025 23:16
Introduces a fast-path for fully contiguous arrays to avoid unnecessary copying and allocation. Switches between linear, binary, and exponential search based on axis size, improving performance for both small and large inputs. Reduces overhead by using strided iterators directly for non-contiguous data.
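
For readers following along, the size-based strategy switch described in this commit can be sketched in pure Python. This is only an illustration of the idea, not the PR's C++ code: the thresholds `LINEAR_MAX` and `BINARY_MAX` and all function names are hypothetical, and the commit message does not state the real cutoffs.

```python
from bisect import bisect_left

# Hypothetical cutoffs; the PR does not state the actual values.
LINEAR_MAX = 16
BINARY_MAX = 4096


def linear_search(a, x):
    """Return the first index i with a[i] >= x by scanning left to right."""
    for i, item in enumerate(a):
        if item >= x:
            return i
    return len(a)


def exponential_search(a, x):
    """Double a probe index until it passes x, then bisect inside that window."""
    n = len(a)
    if n == 0:
        return 0
    bound = 1
    while bound < n and a[bound] < x:
        bound *= 2
    # The insertion point now lies in [bound // 2, min(bound, n)].
    return bisect_left(a, x, bound // 2, min(bound, n))


def searchsorted_left(a, x):
    """Pick a search strategy from the length of the sorted sequence."""
    n = len(a)
    if n <= LINEAR_MAX:
        return linear_search(a, x)
    if n <= BINARY_MAX:
        return bisect_left(a, x)
    return exponential_search(a, x)


print(searchsorted_left([1, 3, 5, 7], 4))           # 2
print(searchsorted_left(list(range(10000)), 1234))  # 1234
```

All three strategies return the same left insertion point; only the cost profile differs, which is why a dispatch on axis size can be swapped in without changing results.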

ppaleja commented Dec 12, 2025

Ran benchmarks on a personal branch, profiling-searchsorted, to avoid commit clutter. The current implementation is faster (on CPU) than the naive implementations from the discussion in #1255, but about 1.5x slower than PyTorch's searchsorted.

ppaleja marked this pull request as ready for review on December 12, 2025 20:09
Copilot AI review requested due to automatic review settings December 12, 2025 20:09
Copilot AI left a comment

Pull request overview

This PR adds the searchsorted function to MLX, implementing binary search functionality similar to NumPy's searchsorted. This is a work in progress addressing feature request #1255, with CPU implementation complete but GPU implementation and full axis parameter support still pending.

Key Changes:

  • Implements C++ and Python APIs for mlx.core.searchsorted with binary search on CPU
  • Adds comprehensive test coverage for various edge cases, dtypes, and special values
  • Includes optimizations for different array sizes (linear scan for small arrays, exponential search for large ones)
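
As a point of reference for the "left" and "right" sides mentioned in this PR, here is a minimal sketch using Python's standard `bisect` module. The helper name `searchsorted_1d` is hypothetical and is not part of the MLX API; it only illustrates the insertion-point semantics for a 1D sorted list.

```python
from bisect import bisect_left, bisect_right


def searchsorted_1d(a, values, side="left"):
    """Return insertion indices into sorted list `a` for each value."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    find = bisect_left if side == "left" else bisect_right
    return [find(a, v) for v in values]


a = [1, 2, 2, 2, 5]
print(searchsorted_1d(a, [2, 3], side="left"))   # [1, 4]
print(searchsorted_1d(a, [2, 3], side="right"))  # [4, 4]
```

The two sides differ only when the searched value is already present: "left" lands before the run of equal elements and "right" lands after it.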

Reviewed changes

Copilot reviewed 12 out of 13 changed files in this pull request and generated 14 comments.

File                        Description
mlx/ops.h                   Adds function declarations for searchsorted with and without axis parameter
mlx/ops.cpp                 Implements main searchsorted logic including type promotion and axis handling
mlx/primitives.h            Defines SearchSorted primitive class with axis and side parameters
mlx/primitives.cpp          Implements SearchSorted primitive methods including vmap and output shape computation
mlx/backend/cpu/sort.cpp    Implements CPU evaluation with optimizations for contiguous arrays and different size ranges
python/src/ops.cpp          Adds Python bindings with documentation for the searchsorted function
tests/ops_tests.cpp         Adds C++ tests covering various dtypes, edge cases, and special values
python/tests/test_ops.py    Adds Python tests verifying NumPy compatibility for basic functionality and edge cases
docs/src/python/ops.rst     Adds searchsorted to the API documentation list
mlx/fast_primitives.h       Adds missing include for mlx/array.h
CMakeLists.txt              Adds virtualenv Python detection (unrelated to searchsorted)
.gitignore                  Adds personal development tool entries (unrelated to searchsorted)
.clangd                     Adds clangd configuration file (unrelated to searchsorted)


#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

Copilot AI Dec 12, 2025

The removed include for "mlx/transforms.h" may be needed elsewhere in the file. Only remove this include if you've verified that no part of ops.cpp requires declarations from transforms.h. The fact that transforms_impl.h is still included suggests transforms.h might have been providing necessary declarations.

Suggested change
- #include "mlx/primitives.h"
+ #include "mlx/primitives.h"
+ #include "mlx/transforms.h"

Copilot uses AI. Check for mistakes.
if (to_ax != -1 && to_ax <= axis) {
  axis++;
}
return {{searchsorted(a, v, right_, axis, stream())}, {to_ax}};

Copilot AI Dec 12, 2025

The vmap implementation has an incorrect parameter order in the searchsorted call. The function is called as searchsorted(a, v, right_, axis, stream()), but the function signature in ops.h (lines 731-736) shows it should be searchsorted(a, v, axis, right, stream). The parameters right_ and axis are swapped, which will cause incorrect behavior or compilation errors.

Suggested change
- return {{searchsorted(a, v, right_, axis, stream())}, {to_ax}};
+ return {{searchsorted(a, v, axis, right_, stream())}, {to_ax}};
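
To illustrate why a swap like this is easy to miss: in C++ an int and a bool convert to each other implicitly, so swapped positional arguments may compile without complaint. The Python sketch below shows the same hazard and one common mitigation, keyword-only parameters. Both stub functions are hypothetical stand-ins that only record how their arguments were bound; they are not the MLX signatures.

```python
def searchsorted_positional(a, v, axis, right):
    """Hypothetical stub: records how the positional arguments were bound."""
    return {"axis": axis, "right": right}


def searchsorted_keyword_only(a, v, *, axis, right):
    """Hypothetical stub: axis and right must be passed by name."""
    return {"axis": axis, "right": right}


right_flag, axis = False, 1

# Swapped order is accepted silently: the flag lands in `axis`.
swapped = searchsorted_positional([1, 2], [1], right_flag, axis)
print(swapped)  # {'axis': False, 'right': 1}

# With keyword-only parameters the same mistake is rejected at call time.
try:
    searchsorted_keyword_only([1, 2], [1], right_flag, axis)
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True

# Passing by name makes the intent explicit.
print(searchsorted_keyword_only([1, 2], [1], axis=axis, right=right_flag))
```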

Comment on lines +1 to +5
CompileFlags:
  CompilationDatabase: build

Index:
  Background: Build

Copilot AI Dec 12, 2025

The .clangd configuration file addition appears to be a personal development environment preference unrelated to the searchsorted feature. This should either be removed or submitted as a separate pull request focused on development tooling improvements.

Suggested change
- CompileFlags:
-   CompilationDatabase: build
- Index:
-   Background: Build

Comment on lines +2786 to +2787
If there is no suitable index, return either 0 or N (where N is the
length of the array).

Copilot AI Dec 12, 2025

The documentation states "If there is no suitable index, return either 0 or N (where N is the length of the array)" but this is misleading. The function always returns a valid insertion index between 0 and N (inclusive). When no suitable index exists, it returns 0 for values less than all elements, or N for values greater than all elements. Consider rephrasing to: "Returns indices in the range [0, N] where N is the length of the sorted dimension."

Suggested change
- If there is no suitable index, return either 0 or N (where N is the
- length of the array).
+ Returns indices in the range [0, N] where N is the length of the sorted
+ dimension.
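
A quick way to see the [0, N] range in practice, sketched with Python's standard `bisect` module rather than MLX itself (the variable names are illustrative only):

```python
from bisect import bisect_left

a = [10, 20, 30]
n = len(a)

below = bisect_left(a, -5)   # smaller than every element
inside = bisect_left(a, 25)  # falls between 20 and 30
above = bisect_left(a, 99)   # larger than every element
empty = bisect_left([], 1)   # empty input has N == 0

print(below, inside, above, empty)  # 0 2 3 0
```

Every result is a valid insertion index: 0 and N are ordinary outcomes for values below or above the whole array, not special "not found" markers.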

If there is no suitable index, return either 0 or N (where N is the
length of the array).
axis (int, optional): Axis along which to search. If not provided, the
array is flattened.

Copilot AI Dec 12, 2025

The axis parameter documentation says "If not provided, the array is flattened" but looking at the implementation in ops.cpp, when axis is not provided, it calls reshape(a, {-1}, stream) which flattens the array and then searches along axis 0. This is correct but should be clarified in the documentation to match NumPy's behavior: "If not provided, the array is treated as if it were flattened and searched along a single axis."

Suggested change
- array is flattened.
+ array is treated as if it were flattened and searched along a single axis (axis 0), matching NumPy's behavior.
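
The "flatten, then search along axis 0" behavior described in this comment can be sketched with nested Python lists. The helper `searchsorted_flat` is hypothetical, and the sketch assumes the flattened data is already sorted, which is the caller's responsibility.

```python
from bisect import bisect_left
from itertools import chain


def searchsorted_flat(rows, values):
    """Flatten a 2D nested list row by row, then search the 1D result."""
    flat = list(chain.from_iterable(rows))
    return [bisect_left(flat, v) for v in values]


rows = [[1, 3, 5], [7, 9, 11]]  # row-major flattening gives [1, 3, 5, 7, 9, 11]
print(searchsorted_flat(rows, [4, 9, 12]))  # [2, 4, 6]
```

The returned indices refer to positions in the flattened sequence, so they range over [0, rows * cols] rather than over a single row.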

Comment on lines +18 to +30
if(NOT DEFINED Python_EXECUTABLE AND DEFINED ENV{VIRTUAL_ENV})
  if(WIN32)
    set(_venv_py "$ENV{VIRTUAL_ENV}/Scripts/python.exe")
  else()
    set(_venv_py "$ENV{VIRTUAL_ENV}/bin/python")
  endif()
  if(EXISTS ${_venv_py})
    message(STATUS "Using Python from virtualenv: ${_venv_py}")
    set(Python_EXECUTABLE
        ${_venv_py}
        CACHE FILEPATH "Python executable" FORCE)
  endif()
endif()

Copilot AI Dec 12, 2025

The CMakeLists.txt changes adding virtualenv Python detection appear unrelated to the searchsorted feature. These changes should be submitted as a separate pull request to keep changes focused and easier to review. Including unrelated build system improvements in a feature PR makes it harder to track what changes are part of the feature vs. infrastructure improvements.

Suggested change
- if(NOT DEFINED Python_EXECUTABLE AND DEFINED ENV{VIRTUAL_ENV})
-   if(WIN32)
-     set(_venv_py "$ENV{VIRTUAL_ENV}/Scripts/python.exe")
-   else()
-     set(_venv_py "$ENV{VIRTUAL_ENV}/bin/python")
-   endif()
-   if(EXISTS ${_venv_py})
-     message(STATUS "Using Python from virtualenv: ${_venv_py}")
-     set(Python_EXECUTABLE
-         ${_venv_py}
-         CACHE FILEPATH "Python executable" FORCE)
-   endif()
- endif()

    result_mx = mx.searchsorted(a_mx, v_mx)
    self.assertTrue(np.array_equal(result_np, result_mx))

def test_searchsorted_int64_promotion(self):

Copilot AI Dec 12, 2025

The test name "test_searchsorted_int64_promotion" is misleading as it tests much more than just int64 type promotion. It includes tests for single element arrays, duplicate values, out of range values, special float values (inf, -inf, NaN), output shapes, dtypes, large arrays, 2D arrays, and mixed precision. Consider either renaming this to something more general like "test_searchsorted_edge_cases" or splitting it into multiple focused test methods.

Suggested change
- def test_searchsorted_int64_promotion(self):
+ def test_searchsorted_edge_cases(self):

# Test output dtype is integer
a_mx = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = mx.searchsorted(a_mx, mx.array([2.5, 4.5]))
self.assertTrue(result.dtype in [mx.int32, mx.int64, mx.uint32, mx.uint64])

Copilot AI Dec 12, 2025

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

Suggested change
- self.assertTrue(result.dtype in [mx.int32, mx.int64, mx.uint32, mx.uint64])
+ self.assertIn(result.dtype, [mx.int32, mx.int64, mx.uint32, mx.uint64])
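
The difference in failure messages can be seen with a small standalone sketch using only the standard `unittest` module. Plain strings stand in for the MLX dtypes here, and `failure_message` is a hypothetical helper written just for this illustration.

```python
import unittest


def failure_message(check):
    """Run a failing assertion and return the text of its AssertionError."""
    try:
        check()
    except AssertionError as exc:
        return str(exc)
    return None


case = unittest.TestCase()
allowed = ["int32", "int64", "uint32", "uint64"]  # stand-ins for mx dtypes
dtype = "float32"

msg_true = failure_message(lambda: case.assertTrue(dtype in allowed))
msg_in = failure_message(lambda: case.assertIn(dtype, allowed))

print(msg_true)  # False is not true
print(msg_in)    # 'float32' not found in ['int32', 'int64', 'uint32', 'uint64']
```

With `assertIn`, the failing value and the allowed set both appear in the message, so a failed CI run shows which dtype was actually returned.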

Comment on lines +2264 to +2292
# shape = (3, 4, 5)
# for dtype in ("int32", "float32"):
#     for axis in (None, 0, 1, 2, -1):
#         with self.subTest(dtype=dtype, axis=axis):
#             np.random.seed(0)
#             np_dtype = getattr(np, dtype)
#             a_np = np.sort(
#                 np.random.uniform(0, 100, size=shape).astype(np_dtype),
#                 axis=axis,
#             )
#             a_mx = mx.array(a_np)

#             # Create search values
#             if axis is None:
#                 v_np = np.array([25, 50, 75], dtype=np_dtype)
#             else:
#                 # Create values that broadcast correctly
#                 # v needs to be broadcastable against a_no_axis (2D)
#                 # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2)
#                 v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1)

#             v_mx = mx.array(v_np)

#             result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis)
#             result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis)
#             self.assertTrue(np.array_equal(result_np, result_mx))

# Test edge cases
# Empty array

Copilot AI Dec 12, 2025

This comment appears to contain commented-out code.

Suggested change
- # shape = (3, 4, 5)
- # for dtype in ("int32", "float32"):
- #     for axis in (None, 0, 1, 2, -1):
- #         with self.subTest(dtype=dtype, axis=axis):
- #             np.random.seed(0)
- #             np_dtype = getattr(np, dtype)
- #             a_np = np.sort(
- #                 np.random.uniform(0, 100, size=shape).astype(np_dtype),
- #                 axis=axis,
- #             )
- #             a_mx = mx.array(a_np)
- #             # Create search values
- #             if axis is None:
- #                 v_np = np.array([25, 50, 75], dtype=np_dtype)
- #             else:
- #                 # Create values that broadcast correctly
- #                 # v needs to be broadcastable against a_no_axis (2D)
- #                 # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2)
- #                 v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1)
- #             v_mx = mx.array(v_np)
- #             result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis)
- #             result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis)
- #             self.assertTrue(np.array_equal(result_np, result_mx))
- # Test edge cases
- # Empty array
+ # Test edge cases

Comment on lines +2263 to +2290
# Test multi-dimensional with different axes
# shape = (3, 4, 5)
# for dtype in ("int32", "float32"):
#     for axis in (None, 0, 1, 2, -1):
#         with self.subTest(dtype=dtype, axis=axis):
#             np.random.seed(0)
#             np_dtype = getattr(np, dtype)
#             a_np = np.sort(
#                 np.random.uniform(0, 100, size=shape).astype(np_dtype),
#                 axis=axis,
#             )
#             a_mx = mx.array(a_np)

#             # Create search values
#             if axis is None:
#                 v_np = np.array([25, 50, 75], dtype=np_dtype)
#             else:
#                 # Create values that broadcast correctly
#                 # v needs to be broadcastable against a_no_axis (2D)
#                 # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2)
#                 v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1)

#             v_mx = mx.array(v_np)

#             result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis)
#             result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis)
#             self.assertTrue(np.array_equal(result_np, result_mx))

Copilot AI Dec 12, 2025

This comment appears to contain commented-out code.

Suggested change
- # Test multi-dimensional with different axes
- # shape = (3, 4, 5)
- # for dtype in ("int32", "float32"):
- #     for axis in (None, 0, 1, 2, -1):
- #         with self.subTest(dtype=dtype, axis=axis):
- #             np.random.seed(0)
- #             np_dtype = getattr(np, dtype)
- #             a_np = np.sort(
- #                 np.random.uniform(0, 100, size=shape).astype(np_dtype),
- #                 axis=axis,
- #             )
- #             a_mx = mx.array(a_np)
- #             # Create search values
- #             if axis is None:
- #                 v_np = np.array([25, 50, 75], dtype=np_dtype)
- #             else:
- #                 # Create values that broadcast correctly
- #                 # v needs to be broadcastable against a_no_axis (2D)
- #                 # Reshape to (2, 1, 1) to broadcast to (2, dim1, dim2)
- #                 v_np = np.array([25, 50], dtype=np_dtype).reshape(2, 1, 1)
- #             v_mx = mx.array(v_np)
- #             result_np = searchsorted_numpy(a_np, v_np, side="left", axis=axis)
- #             result_mx = mx.searchsorted(a_mx, v_mx, side="left", axis=axis)
- #             self.assertTrue(np.array_equal(result_np, result_mx))
