Skip to content

Commit 401acee

Browse files
committed
Add standalone searchsorted benchmark for MLX and NumPy
1 parent 6f01658 commit 401acee

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import argparse
2+
import time
3+
import numpy as np
4+
5+
try:
6+
import mlx.core as mx
7+
except Exception as e:
8+
mx = None
9+
10+
def time_fn(fn, iters: int = 10):
11+
# Simple timing helper: run fn() iters times and return average seconds
12+
start = time.perf_counter()
13+
for _ in range(iters):
14+
fn()
15+
end = time.perf_counter()
16+
return (end - start) / iters
17+
18+
19+
def bench_searchsorted_mx(a_sizes, v_sizes, side, dtype):
20+
if mx is None:
21+
raise RuntimeError("mlx.core not available. Install MLX Python first.")
22+
23+
results = []
24+
for n in a_sizes:
25+
for m in v_sizes:
26+
# Create sorted array 'a' and values 'v'
27+
a_np = np.sort(np.random.rand(n).astype(dtype))
28+
v_np = np.random.rand(m).astype(dtype)
29+
30+
a = mx.array(a_np)
31+
v = mx.array(v_np)
32+
33+
# Warm-up
34+
idx = mx.searchsorted(a, v, side=side)
35+
mx.eval(idx)
36+
37+
def _run():
38+
out = mx.searchsorted(a, v, side=side)
39+
mx.eval(out)
40+
return out
41+
42+
t = time_fn(_run)
43+
results.append((n, m, t))
44+
return results
45+
46+
47+
def bench_searchsorted_numpy(a_sizes, v_sizes, side, dtype):
48+
results = []
49+
for n in a_sizes:
50+
for m in v_sizes:
51+
a = np.sort(np.random.rand(n).astype(dtype))
52+
v = np.random.rand(m).astype(dtype)
53+
54+
# Warm-up
55+
_ = np.searchsorted(a, v, side=side)
56+
57+
def _run():
58+
return np.searchsorted(a, v, side=side)
59+
60+
t = time_fn(_run)
61+
results.append((n, m, t))
62+
return results
63+
64+
65+
def fmt_results(tag, results):
66+
print(f"\n{tag} results (a_size, v_size, time_ms):")
67+
for n, m, t in results:
68+
print(f"{n:>8} {m:>8} {t*1e3:8.3f}")
69+
70+
71+
def main():
72+
parser = argparse.ArgumentParser(description="Benchmark searchsorted for MLX vs NumPy")
73+
parser.add_argument("--side", choices=["left", "right"], default="left")
74+
parser.add_argument("--dtype", choices=["float32", "float64"], default="float32")
75+
parser.add_argument("--a-sizes", type=int, nargs="*", default=[1_000, 10_000, 100_000, 1_000_000])
76+
parser.add_argument("--v-sizes", type=int, nargs="*", default=[10, 100, 1_000, 10_000])
77+
args = parser.parse_args()
78+
79+
dtype = np.float32 if args.dtype == "float32" else np.float64
80+
81+
np_results = bench_searchsorted_numpy(args.a_sizes, args.v_sizes, args.side, dtype)
82+
fmt_results("NumPy", np_results)
83+
84+
try:
85+
mx_results = bench_searchsorted_mx(args.a_sizes, args.v_sizes, args.side, dtype)
86+
fmt_results("MLX", mx_results)
87+
except Exception as e:
88+
print(f"\nMLX benchmark skipped: {e}")
89+
90+
91+
if __name__ == "__main__":
92+
main()

0 commit comments

Comments
 (0)