Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions src/root.zig
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ pub const ReportOptions = struct {
pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anytype, options: Options) !Metrics {
assertFunctionDef(function, args);

// ref: https://pyk.sh/blog/2025-12-08-bench-fixing-constant-folding
var runtime_args = createRuntimeArgs(function, args);
const volatile_args_ptr: *volatile @TypeOf(runtime_args) = &runtime_args;

for (0..options.warmup_iters) |_| {
std.mem.doNotOptimizeAway(args);
try execute(function, args);
try execute(function, volatile_args_ptr.*);
}

// We need to determine a batch_size such that the total execution time of the batch
Expand All @@ -60,8 +63,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt
while (true) {
timer.reset();
for (0..batch_size) |_| {
std.mem.doNotOptimizeAway(args);
try execute(function, args);
try execute(function, volatile_args_ptr.*);
}
const duration = timer.read();

Expand All @@ -87,8 +89,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt
for (0..options.sample_size) |i| {
timer.reset();
for (0..batch_size) |_| {
std.mem.doNotOptimizeAway(args);
try execute(function, args);
try execute(function, volatile_args_ptr.*);
}
const total_ns = timer.read();
// Average time per operation for this batch
Expand Down Expand Up @@ -141,8 +142,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt
try perf.capture();
for (0..options.sample_size) |_| {
for (0..batch_size) |_| {
std.mem.doNotOptimizeAway(args);
try execute(function, args);
try execute(function, volatile_args_ptr.*);
}
}
try perf.stop();
Expand All @@ -165,48 +165,86 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt
return metrics;
}

inline fn execute(function: anytype, args: anytype) !void {
const FnType = unwrapFnType(@TypeOf(function));
const return_type = @typeInfo(FnType).@"fn".return_type.?;
// Conditional execution based on whether the function can fail
if (@typeInfo(return_type) == .error_union) {
const result = try @call(.auto, function, args);
std.mem.doNotOptimizeAway(result);
} else {
const result = @call(.auto, function, args);
std.mem.doNotOptimizeAway(result);
}
}

/// Returns the underlying Function type, unwrapping it if it is a pointer.
fn unwrapFnType(comptime T: type) type {
if (@typeInfo(T) == .pointer) return @typeInfo(T).pointer.child;
return T;
}

////////////////////////////////////////////////////////////////////////////////
// Function definition checker

fn assertFunctionDef(function: anytype, args: anytype) void {
// Verify args is a tuple
const ArgsType = @TypeOf(args);
const args_info = @typeInfo(ArgsType);
if (args_info != .@"struct" or !args_info.@"struct".is_tuple) {
@compileError("Expected 'args' to be a tuple, found " ++ @typeName(ArgsType));
@compileError("Expected 'args' to be a tuple, found '" ++ @typeName(ArgsType) ++ "'");
}

// Unwrap function type
const FnType = @TypeOf(function);
const UnwrappedFnType = if (@typeInfo(FnType) == .pointer)
@typeInfo(FnType).pointer.child
else
FnType;
const fn_info = @typeInfo(UnwrappedFnType);
if (fn_info != .@"fn") {
@compileError("Expected 'function' to be a function or function pointer, found " ++ @typeName(@TypeOf(function)));
const FnType = unwrapFnType(@TypeOf(function));
if (@typeInfo(FnType) != .@"fn") {
@compileError("Expected 'function' to be a function or function pointer, found '" ++ @typeName(@TypeOf(function)) ++ "'");
}

// Verify argument count matches
if (fn_info.@"fn".params.len != args_info.@"struct".fields.len) {
const params_len = @typeInfo(FnType).@"fn".params.len;
const args_len = @typeInfo(ArgsType).@"struct".fields.len;

if (params_len != args_len) {
@compileError(std.fmt.comptimePrint(
"Function expects {d} arguments, but args tuple has {d}",
.{ fn_info.@"fn".params.len, args_info.@"struct".fields.len },
.{ params_len, args_len },
));
}
}

inline fn execute(function: anytype, args: anytype) !void {
const FnType = @TypeOf(function);
const UnwrappedFnType = if (@typeInfo(FnType) == .pointer)
@typeInfo(FnType).pointer.child
else
FnType;
const return_type = @typeInfo(UnwrappedFnType).@"fn".return_type.?;
if (@typeInfo(return_type) == .error_union) {
const result = try @call(.auto, function, args);
std.mem.doNotOptimizeAway(result);
} else {
const result = @call(.auto, function, args);
std.mem.doNotOptimizeAway(result);
////////////////////////////////////////////////////////////////////////////////
// Runtime Arguments Helpers

/// Constructs the runtime argument tuple based on function parameters and input args.
fn createRuntimeArgs(function: anytype, args: anytype) RuntimeArgsType(@TypeOf(function), @TypeOf(args)) {
const TupleType = RuntimeArgsType(@TypeOf(function), @TypeOf(args));
var runtime_args: TupleType = undefined;

// We only need the length here to iterate
const fn_params = getFnParams(@TypeOf(function));

inline for (0..fn_params.len) |i| {
runtime_args[i] = args[i];
}
return runtime_args;
}

/// Computes the precise Tuple type required to hold the arguments.
fn RuntimeArgsType(comptime FnType: type, comptime ArgsType: type) type {
const fn_params = getFnParams(FnType);
const args_fields = @typeInfo(ArgsType).@"struct".fields;
comptime var types: [fn_params.len]type = undefined;
inline for (fn_params, 0..) |p, i| {
if (p.type) |t| {
types[i] = t;
} else {
types[i] = args_fields[i].type;
}
}
return std.meta.Tuple(&types);
}

/// Helper to unwrap function pointers and retrieve parameter info
fn getFnParams(comptime FnType: type) []const std.builtin.Type.Fn.Param {
return @typeInfo(unwrapFnType(FnType)).@"fn".params;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
34 changes: 34 additions & 0 deletions src/test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,37 @@ test "run: suppported signatures" {
_ = try bench.run(allocator, "functionReturnValue", functionReturnValue, .{}, .{});
_ = try bench.run(allocator, "functionReturnValueError", functionReturnValueError, .{}, .{});
}

///////////////////////////////////////////////////////////////////////////////
// Fibonacci

fn fibNaive(n: u64) u64 {
if (n <= 1) return n;
return fibNaive(n - 1) + fibNaive(n - 2);
}

fn fibIterative(n: u64) u64 {
if (n == 0) return 0;

var a: u64 = 0;
var b: u64 = 1;
for (2..n + 1) |_| {
const c = a + b;
a = b;
b = c;
}

return b;
}

test "run: fibonacci" {
const allocator = std.heap.smp_allocator;
const opts = bench.Options{
.sample_size = 100,
.warmup_iters = 3,
};
const m_naive = try bench.run(allocator, "fibNaive", fibNaive, .{@as(u64, 30)}, opts);
const m_iter = try bench.run(allocator, "fibIterative", fibIterative, .{@as(u64, 30)}, opts);

try testing.expect(m_naive.mean_ns > m_iter.mean_ns * 100);
}