From 9337a44ead474cde36830e658694938d3d821788 Mon Sep 17 00:00:00 2001 From: pyk <2213646+pyk@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:52:16 +0700 Subject: [PATCH] fix(run): prevent constant folding --- src/root.zig | 108 ++++++++++++++++++++++++++++++++++----------------- src/test.zig | 34 ++++++++++++++++ 2 files changed, 107 insertions(+), 35 deletions(-) diff --git a/src/root.zig b/src/root.zig index b8dd3e0..05f2620 100644 --- a/src/root.zig +++ b/src/root.zig @@ -45,9 +45,12 @@ pub const ReportOptions = struct { pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anytype, options: Options) !Metrics { assertFunctionDef(function, args); + // ref: https://pyk.sh/blog/2025-12-08-bench-fixing-constant-folding + var runtime_args = createRuntimeArgs(function, args); + const volatile_args_ptr: *volatile @TypeOf(runtime_args) = &runtime_args; + for (0..options.warmup_iters) |_| { - std.mem.doNotOptimizeAway(args); - try execute(function, args); + try execute(function, volatile_args_ptr.*); } // We need to determine a batch_size such that the total execution time of the batch @@ -60,8 +63,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt while (true) { timer.reset(); for (0..batch_size) |_| { - std.mem.doNotOptimizeAway(args); - try execute(function, args); + try execute(function, volatile_args_ptr.*); } const duration = timer.read(); @@ -87,8 +89,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt for (0..options.sample_size) |i| { timer.reset(); for (0..batch_size) |_| { - std.mem.doNotOptimizeAway(args); - try execute(function, args); + try execute(function, volatile_args_ptr.*); } const total_ns = timer.read(); // Average time per operation for this batch @@ -141,8 +142,7 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt try perf.capture(); for (0..options.sample_size) |_| { for (0..batch_size) |_| { - std.mem.doNotOptimizeAway(args); - try execute(function, args); + try execute(function, volatile_args_ptr.*); } } try perf.stop(); @@ -165,48 +165,86 @@ pub fn run(allocator: Allocator, name: []const u8, function: anytype, args: anyt return metrics; } +inline fn execute(function: anytype, args: anytype) !void { + const FnType = unwrapFnType(@TypeOf(function)); + const return_type = @typeInfo(FnType).@"fn".return_type.?; + // Conditional execution based on whether the function can fail + if (@typeInfo(return_type) == .error_union) { + const result = try @call(.auto, function, args); + std.mem.doNotOptimizeAway(result); + } else { + const result = @call(.auto, function, args); + std.mem.doNotOptimizeAway(result); + } +} + +/// Returns the underlying Function type, unwrapping it if it is a pointer. +fn unwrapFnType(comptime T: type) type { + if (@typeInfo(T) == .pointer) return @typeInfo(T).pointer.child; + return T; +} + +//////////////////////////////////////////////////////////////////////////////// +// Function definition checker + fn assertFunctionDef(function: anytype, args: anytype) void { - // Verify args is a tuple const ArgsType = @TypeOf(args); const args_info = @typeInfo(ArgsType); if (args_info != .@"struct" or !args_info.@"struct".is_tuple) { - @compileError("Expected 'args' to be a tuple, found " ++ @typeName(ArgsType)); + @compileError("Expected 'args' to be a tuple, found '" ++ @typeName(ArgsType) ++ "'"); } - // Unwrap function type - const FnType = @TypeOf(function); - const UnwrappedFnType = if (@typeInfo(FnType) == .pointer) - @typeInfo(FnType).pointer.child - else - FnType; - const fn_info = @typeInfo(UnwrappedFnType); - if (fn_info != .@"fn") { - @compileError("Expected 'function' to be a function or function pointer, found " ++ @typeName(@TypeOf(function))); + const FnType = unwrapFnType(@TypeOf(function)); + if (@typeInfo(FnType) != .@"fn") { + @compileError("Expected 'function' to be a function or function pointer, found '" ++ @typeName(@TypeOf(function)) ++ "'"); } - // Verify argument count matches - if (fn_info.@"fn".params.len != args_info.@"struct".fields.len) { + const params_len = @typeInfo(FnType).@"fn".params.len; + const args_len = @typeInfo(ArgsType).@"struct".fields.len; + + if (params_len != args_len) { @compileError(std.fmt.comptimePrint( "Function expects {d} arguments, but args tuple has {d}", - .{ fn_info.@"fn".params.len, args_info.@"struct".fields.len }, + .{ params_len, args_len }, )); } } -inline fn execute(function: anytype, args: anytype) !void { - const FnType = @TypeOf(function); - const UnwrappedFnType = if (@typeInfo(FnType) == .pointer) - @typeInfo(FnType).pointer.child - else - FnType; - const return_type = @typeInfo(UnwrappedFnType).@"fn".return_type.?; - if (@typeInfo(return_type) == .error_union) { - const result = try @call(.auto, function, args); - std.mem.doNotOptimizeAway(result); - } else { - const result = @call(.auto, function, args); - std.mem.doNotOptimizeAway(result); +//////////////////////////////////////////////////////////////////////////////// +// Runtime Arguments Helpers + +/// Constructs the runtime argument tuple based on function parameters and input args. +fn createRuntimeArgs(function: anytype, args: anytype) RuntimeArgsType(@TypeOf(function), @TypeOf(args)) { + const TupleType = RuntimeArgsType(@TypeOf(function), @TypeOf(args)); + var runtime_args: TupleType = undefined; + + // We only need the length here to iterate + const fn_params = getFnParams(@TypeOf(function)); + + inline for (0..fn_params.len) |i| { + runtime_args[i] = args[i]; } + return runtime_args; +} + +/// Computes the precise Tuple type required to hold the arguments. +fn RuntimeArgsType(comptime FnType: type, comptime ArgsType: type) type { + const fn_params = getFnParams(FnType); + const args_fields = @typeInfo(ArgsType).@"struct".fields; + comptime var types: [fn_params.len]type = undefined; + inline for (fn_params, 0..) |p, i| { + if (p.type) |t| { + types[i] = t; + } else { + types[i] = args_fields[i].type; + } + } + return std.meta.Tuple(&types); +} + +/// Helper to unwrap function pointers and retrieve parameter info +fn getFnParams(comptime FnType: type) []const std.builtin.Type.Fn.Param { + return @typeInfo(unwrapFnType(FnType)).@"fn".params; } //////////////////////////////////////////////////////////////////////////////// diff --git a/src/test.zig b/src/test.zig index 95f6140..c9462a4 100644 --- a/src/test.zig +++ b/src/test.zig @@ -215,3 +215,37 @@ test "run: suppported signatures" { _ = try bench.run(allocator, "functionReturnValue", functionReturnValue, .{}, .{}); _ = try bench.run(allocator, "functionReturnValueError", functionReturnValueError, .{}, .{}); } + +/////////////////////////////////////////////////////////////////////////////// +// Fibonacci + +fn fibNaive(n: u64) u64 { + if (n <= 1) return n; + return fibNaive(n - 1) + fibNaive(n - 2); +} + +fn fibIterative(n: u64) u64 { + if (n == 0) return 0; + + var a: u64 = 0; + var b: u64 = 1; + for (2..n + 1) |_| { + const c = a + b; + a = b; + b = c; + } + + return b; +} + +test "run: fibonacci" { + const allocator = std.heap.smp_allocator; + const opts = bench.Options{ + .sample_size = 100, + .warmup_iters = 3, + }; + const m_naive = try bench.run(allocator, "fibNaive", fibNaive, .{@as(u64, 30)}, opts); + const m_iter = try bench.run(allocator, "fibIterative", fibIterative, .{@as(u64, 30)}, opts); + + try testing.expect(m_naive.mean_ns > m_iter.mean_ns * 100); +}