From eb2351c4a0877970f97e6f60875edd5812ebc803 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 6 Jan 2026 10:01:55 +0000 Subject: [PATCH 1/7] Add optimal & reference implementations with tests and benchmarks. --- examples/OneTrans/batched_gemm_example.py | 362 ++++++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 examples/OneTrans/batched_gemm_example.py diff --git a/examples/OneTrans/batched_gemm_example.py b/examples/OneTrans/batched_gemm_example.py new file mode 100644 index 000000000..00e584980 --- /dev/null +++ b/examples/OneTrans/batched_gemm_example.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Grouped Linear Layer with Strided BMM Optimization + +================================================================================ +Problem +================================================================================ +Apply num_groups different linear transformations to corresponding slices of input: + + Input: x of shape (B * num_groups, input_dim) + Output: y of shape (B * num_groups, output_dim) + + For each group n: y[b, n, :] = x[b, n, :] @ W[n, :, :] + +================================================================================ +Reference Implementation +================================================================================ +The straightforward approach uses a loop over groups: + + x = x.reshape(B, num_groups, D_in) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(num_groups): + x_i = x_split[i].squeeze(1) # (B, D_in) + out_i = linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, D_out) + +================================================================================ +Optimized Implementation +================================================================================ +Use torch.bmm with strided output to fuse all GEMMs into one kernel: + + x = x.reshape(B, num_groups, D_in) + output = torch.empty(B, num_groups, D_out, ...) # pre-allocate final layout + torch.bmm(x.permute(1,0,2), weight, + out=output.permute(1,0,2)) # cuBLAS writes to strided memory + return output.view(-1, D_out) # O(1) view, no copy. + +Key feature: cuBLAS strided batched GEMM supports strided output via ldc/strideC +parameters, allowing direct write to the transposed memory layout. + +================================================================================ +Performance Results +================================================================================ +Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16 +Device: NVIDIA H100 +BMM_Opt Forward: 1.46x +BMM_Opt Forward+Backward:1.41x + +================================================================================ +""" + +import argparse +from typing import List, Tuple + +import torch +import torch.nn as nn + + +def warmup_gpu(): + """Warmup GPU to get stable timing""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +class ReferenceImpl(nn.Module): + """ + Reference implementation using reshape + split + loop + stack. + Simple but slow due to multiple kernel launches. 
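    With the benchmark defaults (num_groups=12) each forward call launches one GEMM
    per group plus a stack/copy. A minimal way to see this launch overhead (an
    illustrative sketch, not part of the benchmark; assumes a `ref_model` and an
    input `x` built as in main()):

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA]
        ) as prof:
            _ = ref_model(x)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))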
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear_layers = nn.ModuleList( + [ + nn.Linear(input_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + + for layer in self.linear_layers: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # reshape: (B*ns, D) -> (B, ns, D) + x = x.reshape(-1, self.num_groups, self.input_dim) + + # split and loop + x_split = torch.split(x, 1, dim=1) + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) # (B, D) + out_i = self.linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + # stack: ns * (B, D_out) -> (B, ns, D_out) -> (B*ns, D_out) + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim): + ctx.save_for_backward(x, weight) + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + grad_x = grad_weight = None + grad_output_t = grad_output.permute(1, 0, 2) + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) + + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm(x.permute(1, 0, 2).transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight, None, None, None + + +class BmmImpl(nn.Module): + """ + Optimized implementation using strided BMM. + Single kernel launch with fused permute via strided output. 
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + + self.weight = nn.Parameter( + torch.empty(num_groups, input_dim, output_dim, device=device, dtype=dtype) + ) + for i in range(num_groups): + nn.init.xavier_normal_(self.weight[i], gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + x = x.reshape(batch_size, self.num_groups, self.input_dim) + output = StridedBmmFunction.apply( + x, self.weight, batch_size, self.num_groups, self.output_dim + ) + return output.view(-1, self.output_dim) + + +def copy_weights(ref_model: ReferenceImpl, opt_model: BmmImpl): + """Copy weights from reference to optimized model.""" + with torch.no_grad(): + for i in range(ref_model.num_groups): + opt_model.weight[i].copy_(ref_model.linear_layers[i].weight.T) + + +def check_correctness( + ref_model: ReferenceImpl, + opt_model: BmmImpl, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float, float]: + """Check forward and backward correctness.""" + # Forward check + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + # Backward check + x_ref = torch.randn( + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + # Input gradient + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + # Weight gradient + ref_weight_grad = torch.stack( + [ref_model.linear_layers[i].weight.grad.T for i in range(num_groups)] + ) + bwd_w_diff = (ref_weight_grad - opt_model.weight.grad).abs().max().item() + + return fwd_diff, bwd_x_diff, bwd_w_diff + + +def benchmark( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + with_backward: bool = False, +) -> float: + """Benchmark forward or forward+backward pass using CUDA events for accurate GPU timing.""" + if with_backward: + x_list = [xi.requires_grad_(True) for xi in x_list] + grad_outputs = [ + torch.randn(xi.shape[0], model.output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations # ms + + +def main(): + parser = argparse.ArgumentParser( + description="Grouped GEMM: Reference vs Strided BMM" + ) + parser.add_argument( + "--batch-size", type=int, default=2560, help="Batch size to test" + ) + 
parser.add_argument("--num-groups", type=int, default=12, help="Number of groups") + parser.add_argument("--input-dim", type=int, default=1024, help="Input dimension") + parser.add_argument("--output-dim", type=int, default=3072, help="Output dimension") + parser.add_argument( + "--iterations", type=int, default=100, help="Number of iterations for timing" + ) + args = parser.parse_args() + + torch.cuda.init() + + # Configuration from args + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + output_dim = args.output_dim + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 60) + print("Grouped GEMM: Reference vs Strided BMM") + print("=" * 60) + print( + f"\nConfig: B={batch_size}, groups={num_groups}, D_in={input_dim}, D_out={output_dim}" + ) + print(f"Device: {torch.cuda.get_device_name(0)}") + + # Warmup GPU + print("\nWarming up GPU...") + warmup_gpu() + + # Create models + ref_model = ReferenceImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() + opt_model = BmmImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() + copy_weights(ref_model, opt_model) + + # Correctness check + print("\n" + "-" * 40) + print("Correctness Check") + print("-" * 40) + fwd_diff, bwd_x_diff, bwd_w_diff = check_correctness( + ref_model, opt_model, batch_size, num_groups, input_dim, dtype + ) + print(f"Forward max diff: {fwd_diff:.2e} {'✓' if fwd_diff < 1e-3 else '✗'}") + print(f"Backward dL/dx diff: {bwd_x_diff:.2e} {'✓' if bwd_x_diff < 1e-3 else '✗'}") + print(f"Backward dL/dW diff: {bwd_w_diff:.2e} {'✓' if bwd_w_diff < 1e-3 else '✗'}") + + # Benchmark + print("\n" + "-" * 40) + print("Performance Benchmark") + print("-" * 40) + + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Forward only + ref_fwd = benchmark(ref_model, x_list, num_iterations, with_backward=False) + opt_fwd = benchmark(opt_model, x_list, num_iterations, with_backward=False) + + print(f"\nForward pass (ms):") + print(f" Reference (loop): {ref_fwd:.4f}") + print(f" Optimized (BMM): {opt_fwd:.4f}") + print(f" Speedup: {ref_fwd/opt_fwd:.2f}x") + + # Forward + Backward + ref_fwdbwd = benchmark(ref_model, x_list, num_iterations, with_backward=True) + opt_fwdbwd = benchmark(opt_model, x_list, num_iterations, with_backward=True) + + print(f"\nForward + Backward (ms):") + print(f" Reference (loop): {ref_fwdbwd:.4f}") + print(f" Optimized (BMM): {opt_fwdbwd:.4f}") + print(f" Speedup: {ref_fwdbwd/opt_fwdbwd:.2f}x") + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() From 553387d7605a5abc841c555fe0371af11593aad2 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Wed, 7 Jan 2026 10:11:37 +0000 Subject: [PATCH 2/7] Add parameter batch_first to generalize GroupedLinear. 
--- .../ops}/batched_gemm_example.py | 90 ++++++++++++++----- 1 file changed, 66 insertions(+), 24 deletions(-) rename examples/{OneTrans => commons/ops}/batched_gemm_example.py (79%) diff --git a/examples/OneTrans/batched_gemm_example.py b/examples/commons/ops/batched_gemm_example.py similarity index 79% rename from examples/OneTrans/batched_gemm_example.py rename to examples/commons/ops/batched_gemm_example.py index 00e584980..71fe3633c 100644 --- a/examples/OneTrans/batched_gemm_example.py +++ b/examples/commons/ops/batched_gemm_example.py @@ -119,36 +119,64 @@ class StridedBmmFunction(torch.autograd.Function): """Custom autograd function for BMM with strided output.""" @staticmethod - def forward(ctx, x, weight, batch_size, num_groups, output_dim): + def forward(ctx, x, weight, batch_size, num_groups, output_dim, batch_first): ctx.save_for_backward(x, weight) - output = torch.empty( - batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype - ) - torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + ctx.batch_first = batch_first + + if batch_first: + # x: [B, G, D] -> need permute to [G, B, D] for bmm + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + else: + # x: [G, B, D] -> already in correct layout for bmm + output = torch.bmm(x, weight) + return output @staticmethod def backward(ctx, grad_output): x, weight = ctx.saved_tensors + batch_first = ctx.batch_first grad_x = grad_weight = None - grad_output_t = grad_output.permute(1, 0, 2) - if ctx.needs_input_grad[0]: - grad_x = torch.empty_like(x) - torch.bmm( - grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) - ) + if batch_first: + # grad_output: [B, G, D_out] + grad_output_t = grad_output.permute(1, 0, 2) # [G, B, D_out] - if ctx.needs_input_grad[1]: - grad_weight = torch.bmm(x.permute(1, 0, 2).transpose(-1, -2), grad_output_t) + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) - return grad_x, grad_weight, None, None, None + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm( + x.permute(1, 0, 2).transpose(-1, -2), grad_output_t + ) + else: + # grad_output: [G, B, D_out] + if ctx.needs_input_grad[0]: + grad_x = torch.bmm(grad_output, weight.transpose(-1, -2)) + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm(x.transpose(-1, -2), grad_output) -class BmmImpl(nn.Module): + return grad_x, grad_weight, None, None, None, None + + +class GroupedLinear(nn.Module): """ - Optimized implementation using strided BMM. - Single kernel launch with fused permute via strided output. + Grouped linear layer: applies num_groups different linear transforms in parallel. + Optimized using batched GEMM with strided output for single kernel launch. + + Args: + input_dim: Input feature dimension. + output_dim: Output feature dimension. + num_groups: Number of independent linear transforms. + batch_first: If True, input layout is [B, G, D] (batch-first, default). + If False, input layout is [G, B, D] (group-first). 
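
    A minimal usage sketch (illustrative shapes only; weights default to bfloat16
    on CUDA, so the input must match):

        layer = GroupedLinear(input_dim=1024, output_dim=3072, num_groups=12)
        x = torch.randn(2560 * 12, 1024, device="cuda", dtype=torch.bfloat16)
        y = layer(x)  # (2560 * 12, 3072); row grouping follows batch_first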
""" def __init__( @@ -158,11 +186,13 @@ def __init__( num_groups: int, device="cuda", dtype=torch.bfloat16, + batch_first: bool = True, ): super().__init__() self.num_groups = num_groups self.input_dim = input_dim self.output_dim = output_dim + self.batch_first = batch_first self.weight = nn.Parameter( torch.empty(num_groups, input_dim, output_dim, device=device, dtype=dtype) @@ -172,14 +202,26 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size = x.shape[0] // self.num_groups - x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.batch_first: + # Input flattened from [B, G, D] -> reshape to [B, G, D] + x = x.reshape(batch_size, self.num_groups, self.input_dim) + else: + # Input flattened from [G, B, D] -> reshape to [G, B, D] + x = x.reshape(self.num_groups, batch_size, self.input_dim) + output = StridedBmmFunction.apply( - x, self.weight, batch_size, self.num_groups, self.output_dim + x, + self.weight, + batch_size, + self.num_groups, + self.output_dim, + self.batch_first, ) return output.view(-1, self.output_dim) -def copy_weights(ref_model: ReferenceImpl, opt_model: BmmImpl): +def copy_weights(ref_model: ReferenceImpl, opt_model: GroupedLinear): """Copy weights from reference to optimized model.""" with torch.no_grad(): for i in range(ref_model.num_groups): @@ -188,7 +230,7 @@ def copy_weights(ref_model: ReferenceImpl, opt_model: BmmImpl): def check_correctness( ref_model: ReferenceImpl, - opt_model: BmmImpl, + opt_model: GroupedLinear, batch_size: int, num_groups: int, input_dim: int, @@ -311,7 +353,7 @@ def main(): # Create models ref_model = ReferenceImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() - opt_model = BmmImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() + opt_model = GroupedLinear(input_dim, output_dim, num_groups, dtype=dtype).cuda() copy_weights(ref_model, opt_model) # Correctness check @@ -341,7 +383,7 @@ def main(): print(f"\nForward pass (ms):") print(f" Reference (loop): {ref_fwd:.4f}") - print(f" Optimized (BMM): {opt_fwd:.4f}") + print(f" GroupedLinear: {opt_fwd:.4f}") print(f" Speedup: {ref_fwd/opt_fwd:.2f}x") # Forward + Backward @@ -350,7 +392,7 @@ def main(): print(f"\nForward + Backward (ms):") print(f" Reference (loop): {ref_fwdbwd:.4f}") - print(f" Optimized (BMM): {opt_fwdbwd:.4f}") + print(f" GroupedLinear: {opt_fwdbwd:.4f}") print(f" Speedup: {ref_fwdbwd/opt_fwdbwd:.2f}x") print("\n" + "=" * 60) From dd7767bf3be97831770df2ebb3d693b1188c8218 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Wed, 14 Jan 2026 08:31:11 +0000 Subject: [PATCH 3/7] Add new example. --- ...mm_example.py => GroupedLinear_example.py} | 0 examples/commons/ops/GroupedMLP_example.py | 927 ++++++++++++++++++ 2 files changed, 927 insertions(+) rename examples/commons/ops/{batched_gemm_example.py => GroupedLinear_example.py} (100%) create mode 100644 examples/commons/ops/GroupedMLP_example.py diff --git a/examples/commons/ops/batched_gemm_example.py b/examples/commons/ops/GroupedLinear_example.py similarity index 100% rename from examples/commons/ops/batched_gemm_example.py rename to examples/commons/ops/GroupedLinear_example.py diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py new file mode 100644 index 000000000..25db46799 --- /dev/null +++ b/examples/commons/ops/GroupedMLP_example.py @@ -0,0 +1,927 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Grouped MLP Benchmark: Reference vs Plan A vs Plan B + +================================================================================ +Problem +================================================================================ +Apply num_groups different MLP transformations with GLU gating (SwiGLU/GeGLU): + + For each group n: y[b, n, :] = down(act(gate(x)) * up(x)) + +================================================================================ +Implementations +================================================================================ +Reference: Loop over groups with separate nn.Linear layers +Plan A: 3 independent strided BMMs (gate, up, down) +Plan B: Fused gate+up into 1 BMM, then down (2 BMMs total) + +================================================================================ +""" + +import argparse +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.nvtx as nvtx + + +def warmup_gpu(): + """Warmup GPU to get stable timing.""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: + """Get activation function by name.""" + if activation is None: + return None + activation_map = { + "silu": F.silu, + "swish": F.silu, + "gelu": F.gelu, + "relu": F.relu, + "tanh": torch.tanh, + "sigmoid": torch.sigmoid, + } + if activation.lower() not in activation_map: + raise ValueError(f"Unknown activation: {activation}") + return activation_map[activation.lower()] + + +# ============================================================================= +# Reference Implementation +# ============================================================================= + +class ReferenceGroupedMLP(nn.Module): + """Reference implementation using loop over groups.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + self.up_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + else: + self.proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + + self.down_proj = nn.ModuleList([ + nn.Linear(hidden_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + + self._init_weights() + + def _init_weights(self): + for module_list in [getattr(self, 'gate_proj', []), + getattr(self, 'up_proj', []), + getattr(self, 'proj', []), + self.down_proj]: + for layer in module_list: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + if enable_nvtx: + with nvtx.range("Ref_reshape"): + x = x.reshape(-1, self.num_groups, self.input_dim) + + with nvtx.range("Ref_split"): + x_split = torch.split(x, 1, dim=1) 
+ + with nvtx.range("Ref_loop_gemm"): + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + with nvtx.range("Ref_stack"): + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + else: + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + return output + + +# ============================================================================= +# Strided BMM Function +# ============================================================================= + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim): + ctx.save_for_backward(x, weight) + + output = torch.empty(batch_size, num_groups, output_dim, + device=x.device, dtype=x.dtype) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + grad_x = grad_weight = None + + grad_output_t = grad_output.permute(1, 0, 2) + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) + + if ctx.needs_input_grad[1]: + x_t = x.permute(1, 0, 2) + grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight, None, None, None + + +# ============================================================================= +# Plan A: 3 Independent BMMs +# ============================================================================= + +class GroupedMLP_PlanA(nn.Module): + """Plan A: 3 independent strided BMMs (gate, up, down).""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.up_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + else: + self.proj_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) 
+ ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + if self.use_gating: + nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) + nn.init.xavier_normal_(self.up_weight[i], gain=1.0) + else: + nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if enable_nvtx: + with nvtx.range("PlanA_reshape"): + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + with nvtx.range("PlanA_gate_bmm"): + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + + with nvtx.range("PlanA_up_bmm"): + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + with nvtx.range("PlanA_proj_bmm"): + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + with nvtx.range("PlanA_down_bmm"): + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + with nvtx.range("PlanA_view"): + return output.view(-1, self.output_dim) + else: + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + return output.view(-1, self.output_dim) + + +# ============================================================================= +# Plan B: Fused gate+up BMM +# ============================================================================= + +class GroupedMLP_PlanB(nn.Module): + """Plan B: Fused gate+up into single BMM (2 BMMs total).""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + proj_out_dim = 2 * hidden_dim if use_gating else hidden_dim + self.proj_weight = nn.Parameter( + torch.empty(num_groups, input_dim, proj_out_dim, device=device, dtype=dtype) + ) + + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + if self.use_gating: + nn.init.xavier_normal_(self.proj_weight[i, :, :self.hidden_dim], gain=1.0) + nn.init.xavier_normal_(self.proj_weight[i, 
:, self.hidden_dim:], gain=1.0) + else: + nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if enable_nvtx: + with nvtx.range("PlanB_reshape"): + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + with nvtx.range("PlanB_fused_proj_bmm"): + proj_out = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.proj_weight.shape[-1] + ) + + with nvtx.range("PlanB_activation"): + if self.use_gating: + gate, up = proj_out.chunk(2, dim=-1) + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + if self.act_fn is not None: + hidden = self.act_fn(proj_out) + else: + hidden = proj_out + + with nvtx.range("PlanB_down_bmm"): + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + with nvtx.range("PlanB_view"): + return output.view(-1, self.output_dim) + else: + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + proj_out = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.proj_weight.shape[-1] + ) + + if self.use_gating: + gate, up = proj_out.chunk(2, dim=-1) + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + if self.act_fn is not None: + hidden = self.act_fn(proj_out) + else: + hidden = proj_out + + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + return output.view(-1, self.output_dim) + + +# ============================================================================= +# Plan C: Native torch.bmm (no custom autograd.Function) +# ============================================================================= + +class GroupedMLP_PlanC(nn.Module): + """ + Plan C: Use native torch.bmm without custom autograd.Function. + + This avoids Python callback overhead in autograd by using only native PyTorch ops. + The tradeoff is an extra permute->contiguous copy, but autograd is pure C++. 
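    The copy shows up when the (G, B, D) result of bmm is flattened back to rows in
    (B, G) order, which cannot be expressed as a view; Plans A/B avoid it by letting
    cuBLAS write directly into the strided (B, G, D) buffer. A small sketch of where
    the copy happens (illustrative; assumes x is (B, G, D_in) and w is
    (G, D_in, D_out) with B, G > 1):

        t = torch.bmm(x.permute(1, 0, 2), w)             # (G, B, D_out), contiguous
        y = t.permute(1, 0, 2).reshape(-1, w.shape[-1])  # flatten in (B, G) row order
        assert y.data_ptr() != t.data_ptr()              # reshape had to materialize a copy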
+ """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + proj_out_dim = 2 * hidden_dim if use_gating else hidden_dim + self.proj_weight = nn.Parameter( + torch.empty(num_groups, input_dim, proj_out_dim, device=device, dtype=dtype) + ) + + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + if self.use_gating: + nn.init.xavier_normal_(self.proj_weight[i, :, :self.hidden_dim], gain=1.0) + nn.init.xavier_normal_(self.proj_weight[i, :, self.hidden_dim:], gain=1.0) + else: + nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if enable_nvtx: + with nvtx.range("PlanC_reshape"): + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + with nvtx.range("PlanC_fused_proj_bmm"): + # Native bmm: (G, B, D_in) @ (G, D_in, D_out) -> (G, B, D_out) + x_t = x.permute(1, 0, 2) # (B, G, D) -> (G, B, D) + proj_out_t = torch.bmm(x_t, self.proj_weight) # (G, B, D_out) + proj_out = proj_out_t.permute(1, 0, 2) # (G, B, D) -> (B, G, D) + + with nvtx.range("PlanC_activation"): + if self.use_gating: + gate, up = proj_out.chunk(2, dim=-1) + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + if self.act_fn is not None: + hidden = self.act_fn(proj_out) + else: + hidden = proj_out + + with nvtx.range("PlanC_down_bmm"): + hidden_t = hidden.permute(1, 0, 2) # (B, G, D) -> (G, B, D) + output_t = torch.bmm(hidden_t, self.down_weight) # (G, B, D_out) + output = output_t.permute(1, 0, 2) # -> (B, G, D_out) + + with nvtx.range("PlanC_view"): + return output.reshape(-1, self.output_dim) + else: + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + # Native bmm with permute (autograd handles this in C++) + x_t = x.permute(1, 0, 2) + proj_out_t = torch.bmm(x_t, self.proj_weight) + proj_out = proj_out_t.permute(1, 0, 2) + + if self.use_gating: + gate, up = proj_out.chunk(2, dim=-1) + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + if self.act_fn is not None: + hidden = self.act_fn(proj_out) + else: + hidden = proj_out + + hidden_t = hidden.permute(1, 0, 2) + output_t = torch.bmm(hidden_t, self.down_weight) + output = output_t.permute(1, 0, 2) + + return output.reshape(-1, self.output_dim) + + +# Alias +GroupedMLP = GroupedMLP_PlanB + + +# ============================================================================= +# Weight Copy Utilities +# ============================================================================= + +def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanA): + with torch.no_grad(): + num_groups = ref_model.num_groups + if ref_model.use_gating: + for i in range(num_groups): + opt_model.gate_weight[i].copy_(ref_model.gate_proj[i].weight.T) + opt_model.up_weight[i].copy_(ref_model.up_proj[i].weight.T) + 
else: + for i in range(num_groups): + opt_model.proj_weight[i].copy_(ref_model.proj[i].weight.T) + for i in range(num_groups): + opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +def copy_weights_to_plan_b(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanB): + with torch.no_grad(): + num_groups = ref_model.num_groups + if ref_model.use_gating: + for i in range(num_groups): + opt_model.proj_weight[i, :, :opt_model.hidden_dim].copy_( + ref_model.gate_proj[i].weight.T + ) + opt_model.proj_weight[i, :, opt_model.hidden_dim:].copy_( + ref_model.up_proj[i].weight.T + ) + else: + for i in range(num_groups): + opt_model.proj_weight[i].copy_(ref_model.proj[i].weight.T) + for i in range(num_groups): + opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +def copy_weights_to_plan_c(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanC): + # Same structure as Plan B + copy_weights_to_plan_b(ref_model, opt_model) + + +# ============================================================================= +# Correctness Check +# ============================================================================= + +def check_correctness( + ref_model: nn.Module, + opt_model: nn.Module, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float]: + """Check forward and backward correctness.""" + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + x_ref = torch.randn( + batch_size * num_groups, input_dim, + device="cuda", dtype=dtype, requires_grad=True + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + return fwd_diff, bwd_x_diff + + +# ============================================================================= +# Benchmark Functions (same as benchmark_batched_gemm.py) +# ============================================================================= + +def benchmark_forward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward pass using CUDA events for accurate GPU timing.""" + model_name = model.__class__.__name__ + + # Warmup + for i in range(num_warmup): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwd_iter{i}"): + _ = model(x_list[i % len(x_list)], enable_nvtx=True) + else: + for i in range(num_iterations): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +def benchmark_forward_backward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward + backward pass using CUDA events.""" + model_name = model.__class__.__name__ + output_dim = model.output_dim + + grad_outputs = [ + 
torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + + x_with_grad = [xi.requires_grad_(True) for xi in x_list] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwdbwd_iter{i}"): + xi = x_with_grad[i % len(x_list)] + grad_out = grad_outputs[i % len(x_list)] + with nvtx.range("forward"): + out = model(xi, enable_nvtx=True) + # Separate backward into two parts for clearer profiling + with nvtx.range("backward_activation"): + # dL/dx (activation gradient) + grad_x = torch.autograd.grad( + outputs=out, + inputs=xi, + grad_outputs=grad_out, + retain_graph=True, # Keep graph for weight gradient + ) + with nvtx.range("backward_weight"): + # dL/dW (weight gradient) + grad_w = torch.autograd.grad( + outputs=out, + inputs=params, + grad_outputs=grad_out, + retain_graph=False, + ) + else: + for i in range(num_iterations): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description="Grouped MLP Benchmark: Reference vs Plan A vs Plan B" + ) + parser.add_argument("--batch-size", type=int, default=2560) + parser.add_argument("--num-groups", type=int, default=12) + parser.add_argument("--input-dim", type=int, default=1024) + parser.add_argument("--hidden-dim", type=int, default=3072) + parser.add_argument("--output-dim", type=int, default=1024) + parser.add_argument("--activation", type=str, default="silu", + choices=["silu", "gelu", "relu", "tanh", "none"]) + parser.add_argument("--no-gating", action="store_true") + parser.add_argument("--iterations", type=int, default=100) + parser.add_argument("--enable-nvtx", action="store_true", + help="Enable NVTX markers (use with nsys profile)") + parser.add_argument("--compile", action="store_true", + help="Use torch.compile() to optimize models") + args = parser.parse_args() + + torch.cuda.init() + + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + hidden_dim = args.hidden_dim + output_dim = args.output_dim + activation = None if args.activation == "none" else args.activation + use_gating = not args.no_gating + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 80) + print("Grouped MLP Benchmark: Reference vs Plan A vs Plan B") + print("=" * 80) + + if args.enable_nvtx: + print("\n*** NVTX PROFILING MODE ***") + print("Run with: nsys profile -o --trace=cuda,nvtx python ...") + + print(f""" +Config: + Batch size: {batch_size} + Num groups: {num_groups} + Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} + Mode: {"GLU (SwiGLU)" if use_gating else "Simple MLP"} + Activation: {activation 
if activation else "None"} + Dtype: {dtype} + Device: {torch.cuda.get_device_name(0)} + Iterations: {num_iterations} +""") + + print("Warming up GPU...") + warmup_gpu() + + # Create models + print("Creating models...") + ref_model = ReferenceGroupedMLP( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + plan_a_model = GroupedMLP_PlanA( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + plan_b_model = GroupedMLP_PlanB( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + plan_c_model = GroupedMLP_PlanC( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + copy_weights_to_plan_a(ref_model, plan_a_model) + copy_weights_to_plan_b(ref_model, plan_b_model) + copy_weights_to_plan_c(ref_model, plan_c_model) + + # Apply torch.compile() if requested + if args.compile: + print("\nApplying torch.compile() to all models...") + ref_model = torch.compile(ref_model) + plan_a_model = torch.compile(plan_a_model) + plan_b_model = torch.compile(plan_b_model) + plan_c_model = torch.compile(plan_c_model) + print("Compilation complete (will JIT compile on first run).") + + # Correctness check + print("-" * 60) + print("Correctness Check") + print("-" * 60) + + fwd_a, bwd_a = check_correctness(ref_model, plan_a_model, batch_size, num_groups, input_dim, dtype) + fwd_b, bwd_b = check_correctness(ref_model, plan_b_model, batch_size, num_groups, input_dim, dtype) + fwd_c, bwd_c = check_correctness(ref_model, plan_c_model, batch_size, num_groups, input_dim, dtype) + + print(f"Plan A - Forward diff: {fwd_a:.2e}, Backward diff: {bwd_a:.2e}") + print(f"Plan B - Forward diff: {fwd_b:.2e}, Backward diff: {bwd_b:.2e}") + print(f"Plan C - Forward diff: {fwd_c:.2e}, Backward diff: {bwd_c:.2e}") + + # Prepare test data + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Benchmark (NEVER use NVTX for timing - NVTX adds Python overhead) + print("\n" + "-" * 60) + print("Performance Benchmark (NVTX disabled for accurate timing)") + print("-" * 60) + + # Forward - always benchmark without NVTX + print("\n>>> Forward Pass <<<") + ref_fwd = benchmark_forward(ref_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwd = benchmark_forward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + plan_b_fwd = benchmark_forward(plan_b_model, x_list, num_iterations, enable_nvtx=False) + plan_c_fwd = benchmark_forward(plan_c_model, x_list, num_iterations, enable_nvtx=False) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (custom autograd)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x") + print(f"{'Plan B (custom autograd)':<30} {plan_b_fwd:<12.4f} {ref_fwd/plan_b_fwd:<10.2f}x") + print(f"{'Plan C (native torch.bmm)':<30} {plan_c_fwd:<12.4f} {ref_fwd/plan_c_fwd:<10.2f}x") + + # Forward + Backward - always benchmark without NVTX + print("\n>>> Forward + Backward <<<") + ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwdbwd = benchmark_forward_backward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + plan_b_fwdbwd = benchmark_forward_backward(plan_b_model, x_list, num_iterations, 
enable_nvtx=False) + plan_c_fwdbwd = benchmark_forward_backward(plan_c_model, x_list, num_iterations, enable_nvtx=False) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (custom autograd)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x") + print(f"{'Plan B (custom autograd)':<30} {plan_b_fwdbwd:<12.4f} {ref_fwdbwd/plan_b_fwdbwd:<10.2f}x") + print(f"{'Plan C (native torch.bmm)':<30} {plan_c_fwdbwd:<12.4f} {ref_fwdbwd/plan_c_fwdbwd:<10.2f}x") + + # NVTX profiling run (separate from benchmark) + if args.enable_nvtx: + print("\n" + "-" * 60) + print("NVTX Profiling Run (for nsys analysis only)") + print("-" * 60) + torch.cuda.profiler.start() + + # Run a few iterations with NVTX for profiling + nvtx_iterations = min(10, num_iterations) + _ = benchmark_forward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward(plan_b_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward(plan_c_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward(plan_b_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward(plan_c_model, x_list, nvtx_iterations, enable_nvtx=True) + + torch.cuda.profiler.stop() + print("NVTX profiling complete.") + + # Summary + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + print(f""" +Implementation Details: + Reference: Loop over {num_groups} groups, uses nn.Linear (C++ autograd) + Plan A/B: Custom StridedBmmFunction (Python autograd callback overhead) + Plan C: Native torch.bmm + permute (C++ autograd, but extra memory copy) + +Forward Speedup: + Plan A vs Reference: {ref_fwd/plan_a_fwd:.2f}x + Plan B vs Reference: {ref_fwd/plan_b_fwd:.2f}x + Plan C vs Reference: {ref_fwd/plan_c_fwd:.2f}x + +Fwd+Bwd Speedup: + Plan A vs Reference: {ref_fwdbwd/plan_a_fwdbwd:.2f}x + Plan B vs Reference: {ref_fwdbwd/plan_b_fwdbwd:.2f}x + Plan C vs Reference: {ref_fwdbwd/plan_c_fwdbwd:.2f}x + +Note: If Plan C is faster than Plan A/B, the bottleneck is Python autograd callback. + Consider using torch.compile() or writing a C++ extension. +""") + print("=" * 80) + print("Done!") + + +if __name__ == "__main__": + main() From dd52bf8a8de514005a51797843a6b9e7119f96cf Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Thu, 15 Jan 2026 09:50:14 +0000 Subject: [PATCH 4/7] Add triton silu+mul to support swiglu. 
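
The Triton kernel fuses the SwiGLU elementwise step, silu(gate) * up, into one
launch for the forward and one for the backward (using d silu/d gate = sigma(gate)
+ gate * sigma(gate) * (1 - sigma(gate))). A quick numerical check against the
unfused PyTorch reference (a sketch; `triton_silu_mul` is the helper added below,
and the loose tolerances only account for bfloat16 rounding):

    import torch
    import torch.nn.functional as F

    gate = torch.randn(4096, 3072, device="cuda", dtype=torch.bfloat16)
    up = torch.randn_like(gate)

    ref = F.silu(gate.float()) * up.float()   # unfused fp32 reference
    out = triton_silu_mul(gate, up)           # fused Triton kernel (bf16 in/out)
    assert torch.allclose(out.float(), ref, rtol=1e-2, atol=1e-2)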
--- examples/commons/ops/GroupedMLP_example.py | 181 ++++++++++++++++++++- 1 file changed, 179 insertions(+), 2 deletions(-) diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py index 25db46799..3a45507d9 100644 --- a/examples/commons/ops/GroupedMLP_example.py +++ b/examples/commons/ops/GroupedMLP_example.py @@ -20,6 +20,8 @@ ================================================================================ """ +import sys +sys.path.insert(0, '/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu') import argparse from typing import Callable, List, Optional, Tuple, Union @@ -28,6 +30,179 @@ import torch.nn as nn import torch.nn.functional as F import torch.cuda.nvtx as nvtx +import triton +import triton.language as tl +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + +from ops.triton_ops.common import triton_autotune + +def silu_configs(): + configs = [] + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [2, 4, 8, 16]: + config = triton.Config({"x_block_size": x_block_size}, num_warps) + configs.append(config) + return configs + + + + + +# ============================================================================= +# Fused SiLU * Up (SwiGLU pattern): output = silu(gate) * up +# ============================================================================= + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_forward( + output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused forward: output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + # silu(gate) = gate * sigmoid(gate) = gate / (1 + exp(-gate)) + silu_gate = fast_dividef(gate, 1.0 + tl.exp(-gate)) + output = (silu_gate * up).to(output_ptr.dtype.element_ty) + + tl.store(output_ptr + x_offset + cols, output, mask=mask) + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_backward( + grad_gate_ptr: tl.tensor, + grad_up_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """ + Fused backward for output = silu(gate) * up + + grad_gate = grad_output * up * d(silu)/d(gate) + = grad_output * up * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate))) + grad_up = grad_output * silu(gate) + """ + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + sigma = tl.sigmoid(gate) + silu_gate = gate * sigma + + # d(silu)/d(gate) = sigma + gate * sigma * (1 - sigma) + dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) + + grad_gate = grad_output * up 
* dsilu_dgate + grad_up = grad_output * silu_gate + + tl.store(grad_gate_ptr + x_offset + cols, grad_gate.to(grad_gate_ptr.dtype.element_ty), mask=mask) + tl.store(grad_up_ptr + x_offset + cols, grad_up.to(grad_up_ptr.dtype.element_ty), mask=mask) + + +def triton_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Forward: output = silu(gate) * up""" + assert gate.shape == up.shape, f"Shape mismatch: gate {gate.shape} vs up {up.shape}" + x_size = gate.numel() + gate_1d = gate.view(-1).contiguous() + up_1d = up.view(-1).contiguous() + output = torch.empty_like(gate_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_forward[grid]( + output, + gate_1d, + up_1d, + x_size, + ) + return output.view(gate.shape) + + +def triton_silu_mul_bwd( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward: returns (grad_gate, grad_up)""" + shape = gate.shape + x_size = gate.numel() + gate_1d = gate.view(-1).contiguous() + up_1d = up.view(-1).contiguous() + grad_output_1d = grad_output.view(-1).contiguous() + grad_gate = torch.empty_like(gate_1d) + grad_up = torch.empty_like(up_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_backward[grid]( + grad_gate, + grad_up, + grad_output_1d, + gate_1d, + up_1d, + x_size, + ) + return grad_gate.view(shape), grad_up.view(shape) + + +class TritonSiluMul(torch.autograd.Function): + """Autograd function for fused silu(gate) * up""" + + @staticmethod + def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + output = triton_silu_mul_fwd(gate, up) + ctx.save_for_backward(gate, up) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + gate, up = ctx.saved_tensors + grad_gate, grad_up = triton_silu_mul_bwd(grad_output, gate, up) + return grad_gate, grad_up + + +def triton_silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Fused SiLU multiplication (SwiGLU pattern). + + Computes: output = silu(gate) * up + + Args: + gate: Input tensor that goes through SiLU activation + up: Input tensor that multiplies with activated gate + + Returns: + output: silu(gate) * up + """ + gate = gate.contiguous() + up = up.contiguous() + return TritonSiluMul.apply(gate, up) def warmup_gpu(): @@ -49,6 +224,7 @@ def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: "relu": F.relu, "tanh": torch.tanh, "sigmoid": torch.sigmoid, + "swiglu": triton_silu_mul, } if activation.lower() not in activation_map: raise ValueError(f"Unknown activation: {activation}") @@ -69,7 +245,7 @@ def __init__( output_dim: int, num_groups: int, use_gating: bool = True, - activation: Optional[str] = "silu", + activation: Optional[str] = "swiglu", device: Union[str, torch.device] = "cuda", dtype: torch.dtype = torch.bfloat16, ): @@ -303,7 +479,8 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: x, self.up_weight, batch_size, self.num_groups, self.hidden_dim ) if self.act_fn is not None: - hidden = self.act_fn(gate) * up + # hidden = self.act_fn(gate) * up + hidden = triton_silu_mul(gate, up) else: hidden = gate * up else: From 6369887cec9e2ca9a2c661f6679c835b5f3b1d9f Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Fri, 16 Jan 2026 08:41:53 +0000 Subject: [PATCH 5/7] Simplified code. 
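
This keeps the Reference and Plan A paths, removes Plan B/C, and adds a
forward_to_hidden() helper that stops before the down projection, i.e. it returns
act(gate(x)) * up(x) with shape (B * G, hidden_dim). A sketch of the intended
relationship on the reference model (illustrative; it replays the same per-group
ops on the same inputs, so the result should match ref_model(x)):

    hidden = ref_model.forward_to_hidden(x)   # (B * G, hidden_dim)
    hidden = hidden.reshape(-1, ref_model.num_groups, ref_model.hidden_dim)
    y = torch.stack(
        [ref_model.down_proj[i](hidden[:, i]) for i in range(ref_model.num_groups)],
        dim=1,
    ).reshape(-1, ref_model.output_dim)       # reproduces ref_model(x)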
--- examples/commons/ops/GroupedMLP_example.py | 369 +++++---------------- 1 file changed, 87 insertions(+), 282 deletions(-) diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py index 3a45507d9..8edac2dcc 100644 --- a/examples/commons/ops/GroupedMLP_example.py +++ b/examples/commons/ops/GroupedMLP_example.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """ -Grouped MLP Benchmark: Reference vs Plan A vs Plan B +Grouped MLP Benchmark: Reference vs Plan A ================================================================================ Problem @@ -16,7 +16,6 @@ ================================================================================ Reference: Loop over groups with separate nn.Linear layers Plan A: 3 independent strided BMMs (gate, up, down) -Plan B: Fused gate+up into 1 BMM, then down (2 BMMs total) ================================================================================ """ @@ -340,6 +339,29 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: return output + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + hidden_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + hidden_list.append(hidden_i) + + return torch.stack(hidden_list, dim=1).reshape(-1, self.hidden_dim) + # ============================================================================= # Strided BMM Function @@ -496,224 +518,30 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: return output.view(-1, self.output_dim) - -# ============================================================================= -# Plan B: Fused gate+up BMM -# ============================================================================= - -class GroupedMLP_PlanB(nn.Module): - """Plan B: Fused gate+up into single BMM (2 BMMs total).""" - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - num_groups: int, - use_gating: bool = True, - activation: Optional[str] = "silu", - device: Union[str, torch.device] = "cuda", - dtype: torch.dtype = torch.bfloat16, - ): - super().__init__() - self.num_groups = num_groups - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.output_dim = output_dim - self.use_gating = use_gating - self.act_fn = get_activation_fn(activation) - - proj_out_dim = 2 * hidden_dim if use_gating else hidden_dim - self.proj_weight = nn.Parameter( - torch.empty(num_groups, input_dim, proj_out_dim, device=device, dtype=dtype) - ) - - self.down_weight = nn.Parameter( - torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) - ) - - self._init_weights() - - def _init_weights(self): - for i in range(self.num_groups): - if self.use_gating: - nn.init.xavier_normal_(self.proj_weight[i, :, :self.hidden_dim], gain=1.0) - nn.init.xavier_normal_(self.proj_weight[i, :, self.hidden_dim:], gain=1.0) - else: - nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) - 
nn.init.xavier_normal_(self.down_weight[i], gain=1.0) - - def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" batch_size = x.shape[0] // self.num_groups + x = x.reshape(batch_size, self.num_groups, self.input_dim) - if enable_nvtx: - with nvtx.range("PlanB_reshape"): - x = x.reshape(batch_size, self.num_groups, self.input_dim) - - with nvtx.range("PlanB_fused_proj_bmm"): - proj_out = StridedBmmFunction.apply( - x, self.proj_weight, batch_size, self.num_groups, self.proj_weight.shape[-1] - ) - - with nvtx.range("PlanB_activation"): - if self.use_gating: - gate, up = proj_out.chunk(2, dim=-1) - if self.act_fn is not None: - hidden = self.act_fn(gate) * up - else: - hidden = gate * up - else: - if self.act_fn is not None: - hidden = self.act_fn(proj_out) - else: - hidden = proj_out - - with nvtx.range("PlanB_down_bmm"): - output = StridedBmmFunction.apply( - hidden, self.down_weight, batch_size, self.num_groups, self.output_dim - ) - - with nvtx.range("PlanB_view"): - return output.view(-1, self.output_dim) - else: - x = x.reshape(batch_size, self.num_groups, self.input_dim) - - proj_out = StridedBmmFunction.apply( - x, self.proj_weight, batch_size, self.num_groups, self.proj_weight.shape[-1] + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim ) - - if self.use_gating: - gate, up = proj_out.chunk(2, dim=-1) - if self.act_fn is not None: - hidden = self.act_fn(gate) * up - else: - hidden = gate * up - else: - if self.act_fn is not None: - hidden = self.act_fn(proj_out) - else: - hidden = proj_out - - output = StridedBmmFunction.apply( - hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim ) - - return output.view(-1, self.output_dim) - - -# ============================================================================= -# Plan C: Native torch.bmm (no custom autograd.Function) -# ============================================================================= - -class GroupedMLP_PlanC(nn.Module): - """ - Plan C: Use native torch.bmm without custom autograd.Function. - - This avoids Python callback overhead in autograd by using only native PyTorch ops. - The tradeoff is an extra permute->contiguous copy, but autograd is pure C++. 
- """ - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - num_groups: int, - use_gating: bool = True, - activation: Optional[str] = "silu", - device: Union[str, torch.device] = "cuda", - dtype: torch.dtype = torch.bfloat16, - ): - super().__init__() - self.num_groups = num_groups - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.output_dim = output_dim - self.use_gating = use_gating - self.act_fn = get_activation_fn(activation) - - proj_out_dim = 2 * hidden_dim if use_gating else hidden_dim - self.proj_weight = nn.Parameter( - torch.empty(num_groups, input_dim, proj_out_dim, device=device, dtype=dtype) - ) - - self.down_weight = nn.Parameter( - torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) - ) - - self._init_weights() - - def _init_weights(self): - for i in range(self.num_groups): - if self.use_gating: - nn.init.xavier_normal_(self.proj_weight[i, :, :self.hidden_dim], gain=1.0) - nn.init.xavier_normal_(self.proj_weight[i, :, self.hidden_dim:], gain=1.0) + if self.act_fn is not None: + hidden = triton_silu_mul(gate, up) else: - nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) - nn.init.xavier_normal_(self.down_weight[i], gain=1.0) - - def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: - batch_size = x.shape[0] // self.num_groups - - if enable_nvtx: - with nvtx.range("PlanC_reshape"): - x = x.reshape(batch_size, self.num_groups, self.input_dim) - - with nvtx.range("PlanC_fused_proj_bmm"): - # Native bmm: (G, B, D_in) @ (G, D_in, D_out) -> (G, B, D_out) - x_t = x.permute(1, 0, 2) # (B, G, D) -> (G, B, D) - proj_out_t = torch.bmm(x_t, self.proj_weight) # (G, B, D_out) - proj_out = proj_out_t.permute(1, 0, 2) # (G, B, D) -> (B, G, D) - - with nvtx.range("PlanC_activation"): - if self.use_gating: - gate, up = proj_out.chunk(2, dim=-1) - if self.act_fn is not None: - hidden = self.act_fn(gate) * up - else: - hidden = gate * up - else: - if self.act_fn is not None: - hidden = self.act_fn(proj_out) - else: - hidden = proj_out - - with nvtx.range("PlanC_down_bmm"): - hidden_t = hidden.permute(1, 0, 2) # (B, G, D) -> (G, B, D) - output_t = torch.bmm(hidden_t, self.down_weight) # (G, B, D_out) - output = output_t.permute(1, 0, 2) # -> (B, G, D_out) - - with nvtx.range("PlanC_view"): - return output.reshape(-1, self.output_dim) + hidden = gate * up else: - x = x.reshape(batch_size, self.num_groups, self.input_dim) - - # Native bmm with permute (autograd handles this in C++) - x_t = x.permute(1, 0, 2) - proj_out_t = torch.bmm(x_t, self.proj_weight) - proj_out = proj_out_t.permute(1, 0, 2) - - if self.use_gating: - gate, up = proj_out.chunk(2, dim=-1) - if self.act_fn is not None: - hidden = self.act_fn(gate) * up - else: - hidden = gate * up - else: - if self.act_fn is not None: - hidden = self.act_fn(proj_out) - else: - hidden = proj_out - - hidden_t = hidden.permute(1, 0, 2) - output_t = torch.bmm(hidden_t, self.down_weight) - output = output_t.permute(1, 0, 2) - - return output.reshape(-1, self.output_dim) - - -# Alias -GroupedMLP = GroupedMLP_PlanB + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + return hidden.view(-1, self.hidden_dim) # ============================================================================= @@ -734,29 +562,6 @@ def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP 
opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) -def copy_weights_to_plan_b(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanB): - with torch.no_grad(): - num_groups = ref_model.num_groups - if ref_model.use_gating: - for i in range(num_groups): - opt_model.proj_weight[i, :, :opt_model.hidden_dim].copy_( - ref_model.gate_proj[i].weight.T - ) - opt_model.proj_weight[i, :, opt_model.hidden_dim:].copy_( - ref_model.up_proj[i].weight.T - ) - else: - for i in range(num_groups): - opt_model.proj_weight[i].copy_(ref_model.proj[i].weight.T) - for i in range(num_groups): - opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) - - -def copy_weights_to_plan_c(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanC): - # Same structure as Plan B - copy_weights_to_plan_b(ref_model, opt_model) - - # ============================================================================= # Correctness Check # ============================================================================= @@ -832,6 +637,32 @@ def benchmark_forward( return start_event.elapsed_time(end_event) / num_iterations +def benchmark_forward_to_hidden( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward pass up to hidden (excluding down projection).""" + # Warmup + for i in range(num_warmup): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + def benchmark_forward_backward( model: nn.Module, x_list: List[torch.Tensor], @@ -913,7 +744,7 @@ def benchmark_forward_backward( def main(): parser = argparse.ArgumentParser( - description="Grouped MLP Benchmark: Reference vs Plan A vs Plan B" + description="Grouped MLP Benchmark: Reference vs Plan A" ) parser.add_argument("--batch-size", type=int, default=2560) parser.add_argument("--num-groups", type=int, default=12) @@ -943,7 +774,7 @@ def main(): num_iterations = args.iterations print("=" * 80) - print("Grouped MLP Benchmark: Reference vs Plan A vs Plan B") + print("Grouped MLP Benchmark: Reference vs Plan A") print("=" * 80) if args.enable_nvtx: @@ -977,27 +808,13 @@ def main(): use_gating=use_gating, activation=activation, dtype=dtype ).cuda() - plan_b_model = GroupedMLP_PlanB( - input_dim, hidden_dim, output_dim, num_groups, - use_gating=use_gating, activation=activation, dtype=dtype - ).cuda() - - plan_c_model = GroupedMLP_PlanC( - input_dim, hidden_dim, output_dim, num_groups, - use_gating=use_gating, activation=activation, dtype=dtype - ).cuda() - copy_weights_to_plan_a(ref_model, plan_a_model) - copy_weights_to_plan_b(ref_model, plan_b_model) - copy_weights_to_plan_c(ref_model, plan_c_model) # Apply torch.compile() if requested if args.compile: print("\nApplying torch.compile() to all models...") ref_model = torch.compile(ref_model) plan_a_model = torch.compile(plan_a_model) - plan_b_model = torch.compile(plan_b_model) - plan_c_model = torch.compile(plan_c_model) print("Compilation complete (will JIT compile on first run).") # Correctness check @@ -1006,12 +823,7 @@ def main(): print("-" * 60) fwd_a, bwd_a = check_correctness(ref_model, plan_a_model, batch_size, num_groups, 
input_dim, dtype) - fwd_b, bwd_b = check_correctness(ref_model, plan_b_model, batch_size, num_groups, input_dim, dtype) - fwd_c, bwd_c = check_correctness(ref_model, plan_c_model, batch_size, num_groups, input_dim, dtype) - print(f"Plan A - Forward diff: {fwd_a:.2e}, Backward diff: {bwd_a:.2e}") - print(f"Plan B - Forward diff: {fwd_b:.2e}, Backward diff: {bwd_b:.2e}") - print(f"Plan C - Forward diff: {fwd_c:.2e}, Backward diff: {bwd_c:.2e}") # Prepare test data x_list = [ @@ -1028,29 +840,31 @@ def main(): print("\n>>> Forward Pass <<<") ref_fwd = benchmark_forward(ref_model, x_list, num_iterations, enable_nvtx=False) plan_a_fwd = benchmark_forward(plan_a_model, x_list, num_iterations, enable_nvtx=False) - plan_b_fwd = benchmark_forward(plan_b_model, x_list, num_iterations, enable_nvtx=False) - plan_c_fwd = benchmark_forward(plan_c_model, x_list, num_iterations, enable_nvtx=False) print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") - print(f"{'Plan A (custom autograd)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x") - print(f"{'Plan B (custom autograd)':<30} {plan_b_fwd:<12.4f} {ref_fwd/plan_b_fwd:<10.2f}x") - print(f"{'Plan C (native torch.bmm)':<30} {plan_c_fwd:<12.4f} {ref_fwd/plan_c_fwd:<10.2f}x") + print(f"{'Plan A (batched BMM)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x") + + # Forward to Hidden (excluding down projection) + print("\n>>> Forward to Hidden (excluding down_proj) <<<") + ref_hidden = benchmark_forward_to_hidden(ref_model, x_list, num_iterations) + plan_a_hidden = benchmark_forward_to_hidden(plan_a_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_hidden:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (batched BMM)':<30} {plan_a_hidden:<12.4f} {ref_hidden/plan_a_hidden:<10.2f}x") # Forward + Backward - always benchmark without NVTX print("\n>>> Forward + Backward <<<") ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations, enable_nvtx=False) plan_a_fwdbwd = benchmark_forward_backward(plan_a_model, x_list, num_iterations, enable_nvtx=False) - plan_b_fwdbwd = benchmark_forward_backward(plan_b_model, x_list, num_iterations, enable_nvtx=False) - plan_c_fwdbwd = benchmark_forward_backward(plan_c_model, x_list, num_iterations, enable_nvtx=False) print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") - print(f"{'Plan A (custom autograd)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x") - print(f"{'Plan B (custom autograd)':<30} {plan_b_fwdbwd:<12.4f} {ref_fwdbwd/plan_b_fwdbwd:<10.2f}x") - print(f"{'Plan C (native torch.bmm)':<30} {plan_c_fwdbwd:<12.4f} {ref_fwdbwd/plan_c_fwdbwd:<10.2f}x") + print(f"{'Plan A (batched BMM)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x") # NVTX profiling run (separate from benchmark) if args.enable_nvtx: @@ -1063,12 +877,8 @@ def main(): nvtx_iterations = min(10, num_iterations) _ = benchmark_forward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) _ = benchmark_forward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward(plan_b_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward(plan_c_model, x_list, nvtx_iterations, enable_nvtx=True) _ = benchmark_forward_backward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) _ = 
benchmark_forward_backward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward_backward(plan_b_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward_backward(plan_c_model, x_list, nvtx_iterations, enable_nvtx=True) torch.cuda.profiler.stop() print("NVTX profiling complete.") @@ -1080,21 +890,16 @@ def main(): print(f""" Implementation Details: Reference: Loop over {num_groups} groups, uses nn.Linear (C++ autograd) - Plan A/B: Custom StridedBmmFunction (Python autograd callback overhead) - Plan C: Native torch.bmm + permute (C++ autograd, but extra memory copy) + Plan A: Batched BMM with custom StridedBmmFunction -Forward Speedup: +Forward Speedup (full MLP): Plan A vs Reference: {ref_fwd/plan_a_fwd:.2f}x - Plan B vs Reference: {ref_fwd/plan_b_fwd:.2f}x - Plan C vs Reference: {ref_fwd/plan_c_fwd:.2f}x + +Forward to Hidden (excluding down_proj): + Plan A vs Reference: {ref_hidden/plan_a_hidden:.2f}x Fwd+Bwd Speedup: Plan A vs Reference: {ref_fwdbwd/plan_a_fwdbwd:.2f}x - Plan B vs Reference: {ref_fwdbwd/plan_b_fwdbwd:.2f}x - Plan C vs Reference: {ref_fwdbwd/plan_c_fwdbwd:.2f}x - -Note: If Plan C is faster than Plan A/B, the bottleneck is Python autograd callback. - Consider using torch.compile() or writing a C++ extension. """) print("=" * 80) print("Done!") From 38b49dcf4f3ef9011337d115a37345dd6167e8dd Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 20 Jan 2026 07:45:46 +0000 Subject: [PATCH 6/7] Add custom op version. --- examples/commons/ops/GroupedMLP_example.py | 144 ++++ examples/commons/ops/grouped_mlp_customop.py | 862 +++++++++++++++++++ 2 files changed, 1006 insertions(+) create mode 100644 examples/commons/ops/grouped_mlp_customop.py diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py index 8edac2dcc..edfc1c108 100644 --- a/examples/commons/ops/GroupedMLP_example.py +++ b/examples/commons/ops/GroupedMLP_example.py @@ -212,6 +212,115 @@ def warmup_gpu(): torch.cuda.synchronize() +def benchmark_fused_silu_mul_bandwidth( + batch_size: int, + num_groups: int, + hidden_dim: int, + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, + num_warmup: int = 10, +) -> Tuple[float, float, float]: + """ + Benchmark fused SiLU * up kernel and calculate memory bandwidth. 
+ + Returns: + (time_ms, bandwidth_gb_s, achieved_percent) + """ + # Create tensors matching the shape used in GroupedMLP + shape = (batch_size, num_groups, hidden_dim) + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + + # Calculate data volume + numel = gate.numel() + bytes_per_element = gate.element_size() # 2 for bf16, 4 for fp32 + # Read: gate + up, Write: output + total_bytes = 3 * numel * bytes_per_element + + # Warmup + for _ in range(num_warmup): + _ = triton_silu_mul(gate, up) + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iterations): + _ = triton_silu_mul(gate, up) + end_event.record() + torch.cuda.synchronize() + + time_ms = start_event.elapsed_time(end_event) / num_iterations + time_s = time_ms / 1000.0 + + # Calculate bandwidth + bandwidth_gb_s = total_bytes / time_s / 1e9 + + # Get theoretical peak bandwidth (for reference) + # H100 SXM: ~3.35 TB/s, A100 SXM: ~2.0 TB/s, A100 PCIe: ~1.9 TB/s + # You can query this from torch.cuda.get_device_properties + props = torch.cuda.get_device_properties(0) + # Memory bandwidth = memory_clock_rate (kHz) * bus_width (bits) * 2 (DDR) / 8 (bits to bytes) + # Note: This is approximate; actual peak may differ + theoretical_peak_gb_s = 2000.0 # Default estimate, adjust based on your GPU + + achieved_percent = bandwidth_gb_s / theoretical_peak_gb_s * 100 + + return time_ms, bandwidth_gb_s, achieved_percent + + +def benchmark_fused_silu_mul_backward_bandwidth( + batch_size: int, + num_groups: int, + hidden_dim: int, + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, + num_warmup: int = 10, +) -> Tuple[float, float]: + """ + Benchmark fused SiLU * up backward kernel and calculate memory bandwidth. 
+ + Backward reads: grad_output, gate, up (3 tensors) + Backward writes: grad_gate, grad_up (2 tensors) + Total: 5 * numel * bytes_per_element + + Returns: + (time_ms, bandwidth_gb_s) + """ + shape = (batch_size, num_groups, hidden_dim) + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + grad_output = torch.randn(shape, device="cuda", dtype=dtype) + + numel = gate.numel() + bytes_per_element = gate.element_size() + # Read: grad_output + gate + up, Write: grad_gate + grad_up + total_bytes = 5 * numel * bytes_per_element + + # Warmup + for _ in range(num_warmup): + _ = triton_silu_mul_bwd(grad_output, gate, up) + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iterations): + _ = triton_silu_mul_bwd(grad_output, gate, up) + end_event.record() + torch.cuda.synchronize() + + time_ms = start_event.elapsed_time(end_event) / num_iterations + time_s = time_ms / 1000.0 + bandwidth_gb_s = total_bytes / time_s / 1e9 + + return time_ms, bandwidth_gb_s + + def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: """Get activation function by name.""" if activation is None: @@ -883,6 +992,37 @@ def main(): torch.cuda.profiler.stop() print("NVTX profiling complete.") + # Fused kernel bandwidth benchmark + print("\n" + "-" * 60) + print("Fused SiLU*Up Kernel Bandwidth Analysis") + print("-" * 60) + + fwd_time, fwd_bw, fwd_pct = benchmark_fused_silu_mul_bandwidth( + batch_size, num_groups, hidden_dim, dtype, num_iterations + ) + bwd_time, bwd_bw = benchmark_fused_silu_mul_backward_bandwidth( + batch_size, num_groups, hidden_dim, dtype, num_iterations + ) + + # Calculate data sizes for reference + numel = batch_size * num_groups * hidden_dim + bytes_per_elem = 2 if dtype == torch.bfloat16 else 4 + fwd_data_mb = 3 * numel * bytes_per_elem / 1e6 + bwd_data_mb = 5 * numel * bytes_per_elem / 1e6 + + print(f"\nTensor shape: ({batch_size}, {num_groups}, {hidden_dim})") + print(f"Elements: {numel:,} ({numel * bytes_per_elem / 1e6:.2f} MB per tensor)") + + print(f"\n{'Kernel':<20} {'Time (ms)':<12} {'Data (MB)':<12} {'BW (GB/s)':<12}") + print("-" * 56) + print(f"{'Forward (silu*up)':<20} {fwd_time:<12.4f} {fwd_data_mb:<12.2f} {fwd_bw:<12.1f}") + print(f"{'Backward':<20} {bwd_time:<12.4f} {bwd_data_mb:<12.2f} {bwd_bw:<12.1f}") + + print(f"\nNote: Peak memory bandwidth varies by GPU:") + print(f" - H100 SXM: ~3350 GB/s") + print(f" - A100 SXM: ~2039 GB/s") + print(f" - A100 PCIe: ~1935 GB/s") + # Summary print("\n" + "=" * 80) print("Summary") @@ -900,6 +1040,10 @@ def main(): Fwd+Bwd Speedup: Plan A vs Reference: {ref_fwdbwd/plan_a_fwdbwd:.2f}x + +Fused SiLU*Up Kernel: + Forward: {fwd_bw:.1f} GB/s + Backward: {bwd_bw:.1f} GB/s """) print("=" * 80) print("Done!") diff --git a/examples/commons/ops/grouped_mlp_customop.py b/examples/commons/ops/grouped_mlp_customop.py new file mode 100644 index 000000000..df126cdc6 --- /dev/null +++ b/examples/commons/ops/grouped_mlp_customop.py @@ -0,0 +1,862 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Grouped MLP as PyTorch Custom Ops + +Structure +[1] Triton Kernels - SiLU*Up kernel +[2] Custom Ops - strided_bmm, silu_mul, grouped_mlp_gated_fwd/bwd +[3] nn.Module Wrappers - GroupedMLP_CustomOp, ReferenceGroupedMLP +[4] Correctness & Benchmark +""" + +import sys +sys.path.insert(0, '/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu') + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +try: + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + from triton.language.math import fast_dividef + +from ops.triton_ops.common import triton_autotune + + +# ============================================================================= +# [1] Triton Kernel Configurations +# ============================================================================= + +def silu_configs(): + configs = [] + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [2, 4, 8, 16]: + config = triton.Config({"x_block_size": x_block_size}, num_warps) + configs.append(config) + return configs + + +# ============================================================================= +# [1] Triton Kernels - SiLU * Up (SwiGLU pattern) +# ============================================================================= + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_forward_kernel( + output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused forward: output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + silu_gate = fast_dividef(gate, 1.0 + tl.exp(-gate)) + output = (silu_gate * up).to(output_ptr.dtype.element_ty) + + tl.store(output_ptr + x_offset + cols, output, mask=mask) + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_backward_kernel( + grad_gate_ptr: tl.tensor, + grad_up_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused backward for output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + sigma = tl.sigmoid(gate) + silu_gate = gate * sigma + dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) + + grad_gate = grad_output * up * dsilu_dgate + grad_up = grad_output * silu_gate + + tl.store(grad_gate_ptr + x_offset + cols, grad_gate.to(grad_gate_ptr.dtype.element_ty), mask=mask) + tl.store(grad_up_ptr + x_offset + cols, grad_up.to(grad_up_ptr.dtype.element_ty), mask=mask) + + +# ============================================================================= +# [1] Triton Kernel Launchers (Internal) +# 
============================================================================= + +def _launch_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Internal: launch forward kernel""" + x_size = gate.numel() + gate_1d = gate.reshape(-1).contiguous() + up_1d = up.reshape(-1).contiguous() + output = torch.empty_like(gate_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_forward_kernel[grid](output, gate_1d, up_1d, x_size) + return output.view(gate.shape) + + +def _launch_silu_mul_bwd( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Internal: launch backward kernel""" + shape = gate.shape + x_size = gate.numel() + gate_1d = gate.reshape(-1).contiguous() + up_1d = up.reshape(-1).contiguous() + grad_output_1d = grad_output.reshape(-1).contiguous() + grad_gate = torch.empty_like(gate_1d) + grad_up = torch.empty_like(up_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_backward_kernel[grid](grad_gate, grad_up, grad_output_1d, gate_1d, up_1d, x_size) + return grad_gate.view(shape), grad_up.view(shape) + + +# ============================================================================= +# [2] Custom Op: silu_mul +# ============================================================================= + +@torch.library.custom_op("grouped_mlp::silu_mul", mutates_args=(), device_types="cuda") +def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Fused SiLU multiplication: output = silu(gate) * up + + This is the SwiGLU activation pattern used in modern LLMs. + """ + torch._check(gate.shape == up.shape, lambda: f"Shape mismatch: {gate.shape} vs {up.shape}") + return _launch_silu_mul_fwd(gate.contiguous(), up.contiguous()) + + +@silu_mul.register_fake +def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + torch._check(gate.shape == up.shape) + return torch.empty_like(gate) + + +@torch.library.custom_op("grouped_mlp::silu_mul_backward", mutates_args=(), device_types="cuda") +def silu_mul_backward( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward for silu_mul: returns (grad_gate, grad_up)""" + return _launch_silu_mul_bwd( + grad_output.contiguous(), gate.contiguous(), up.contiguous() + ) + + +@silu_mul_backward.register_fake +def _(grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(gate), torch.empty_like(up) + + +def _silu_mul_bwd_fn(ctx, grad_output): + gate, up = ctx.saved_tensors + return silu_mul_backward(grad_output, gate, up) + + +def _silu_mul_setup_ctx(ctx, inputs, output): + gate, up = inputs + ctx.save_for_backward(gate, up) + + +silu_mul.register_autograd(_silu_mul_bwd_fn, setup_context=_silu_mul_setup_ctx) + + +# ============================================================================= +# [2] Custom Op: strided_bmm (Strided Batched Matrix Multiply) +# ============================================================================= + +@torch.library.custom_op("grouped_mlp::strided_bmm", mutates_args=(), device_types="cuda") +def strided_bmm( + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) +) -> torch.Tensor: + """ + Strided BMM: output[b, n, :] = x[b, n, :] @ weight[n, :, :] + + Input: x - (batch_size, num_groups, input_dim) + Weight: weight - (num_groups, input_dim, output_dim) + Output: - (batch_size, num_groups, output_dim) + + Internally 
uses permute + torch.bmm for efficiency. + """ + batch_size, num_groups, _ = x.shape + output_dim = weight.shape[2] + + output = torch.empty(batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype) + # x: (B, N, K) -> (N, B, K) + # weight: (N, K, M) + # out: (N, B, M) -> (B, N, M) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + return output + + +@strided_bmm.register_fake +def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + batch_size, num_groups, _ = x.shape + output_dim = weight.shape[2] + return x.new_empty(batch_size, num_groups, output_dim) + + +@torch.library.custom_op("grouped_mlp::strided_bmm_backward", mutates_args=(), device_types="cuda") +def strided_bmm_backward( + grad_output: torch.Tensor, # (B, N, M) + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Backward for strided_bmm. + + grad_x = grad_output @ weight.T => (B, N, K) + grad_weight = x.T @ grad_output => (N, K, M) + """ + grad_output_t = grad_output.permute(1, 0, 2) # (N, B, M) + + # grad_x: (N, B, M) @ (N, M, K) -> (N, B, K) -> (B, N, K) + grad_x = torch.empty_like(x) + torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) + + # grad_weight: (N, K, B) @ (N, B, M) -> (N, K, M) + x_t = x.permute(1, 0, 2) # (N, B, K) + grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight + + +@strided_bmm_backward.register_fake +def _(grad_output: torch.Tensor, x: torch.Tensor, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(weight) + + +def _strided_bmm_bwd_fn(ctx, grad_output): + x, weight = ctx.saved_tensors + return strided_bmm_backward(grad_output, x, weight) + + +def _strided_bmm_setup_ctx(ctx, inputs, output): + x, weight = inputs + ctx.save_for_backward(x, weight) + + +strided_bmm.register_autograd(_strided_bmm_bwd_fn, setup_context=_strided_bmm_setup_ctx) + + +def grouped_mlp_gated_forward( + x: torch.Tensor, # (B*N, D_in) + gate_weight: torch.Tensor, # (N, D_in, D_hidden) + up_weight: torch.Tensor, # (N, D_in, D_hidden) + down_weight: torch.Tensor, # (N, D_hidden, D_out) + num_groups: int, +) -> torch.Tensor: + """ + Grouped MLP forward with gating (SwiGLU pattern). + + This is a composition of custom ops (strided_bmm, silu_mul). + Autograd is handled automatically through the registered backward of each op. 
+ + For each group n: + gate = x @ gate_weight[n] + up = x @ up_weight[n] + hidden = silu(gate) * up + output = hidden @ down_weight[n] + + Args: + x: Input tensor, shape (B*N, D_in) + gate_weight: Gate projection weights, shape (N, D_in, D_hidden) + up_weight: Up projection weights, shape (N, D_in, D_hidden) + down_weight: Down projection weights, shape (N, D_hidden, D_out) + num_groups: Number of groups (N) + + Returns: + Output tensor, shape (B*N, D_out) + """ + batch_size = x.shape[0] // num_groups + input_dim = gate_weight.shape[1] + output_dim = down_weight.shape[2] + + # Reshape: (B*N, D_in) -> (B, N, D_in) + x_3d = x.reshape(batch_size, num_groups, input_dim) + + # Gate BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) + gate = strided_bmm(x_3d, gate_weight) + + # Up BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) + up = strided_bmm(x_3d, up_weight) + + # Fused SiLU activation: hidden = silu(gate) * up + hidden = silu_mul(gate, up) + + # Down BMM: (B, N, D_hidden) @ (N, D_hidden, D_out) -> (B, N, D_out) + output = strided_bmm(hidden, down_weight) + + # Reshape: (B, N, D_out) -> (B*N, D_out) + return output.reshape(-1, output_dim) + + +# ============================================================================= +# [3] nn.Module: GroupedMLP_CustomOp (Plan A using Custom Ops) +# ============================================================================= + +class GroupedMLP_CustomOp(nn.Module): + """ + Grouped MLP using custom ops (Plan A: BMM-based). + + This implementation uses registered custom ops instead of autograd.Function, + making it compatible with torch.compile and scenarios where autograd is not available. + + Forward computation: + For each group n: + gate = x @ gate_weight[n] + up = x @ up_weight[n] + hidden = silu(gate) * up + output = hidden @ down_weight[n] + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + # Weights: (N, D_in, D_hidden) or (N, D_hidden, D_out) + self.gate_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.up_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) + nn.init.xavier_normal_(self.up_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using custom ops. + + Args: + x: Input tensor, shape (B*N, D_in) + + Returns: + Output tensor, shape (B*N, D_out) + """ + return grouped_mlp_gated_forward( + x, + self.gate_weight, + self.up_weight, + self.down_weight, + self.num_groups, + ) + + def forward_decomposed(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with explicit steps (for debugging/profiling). + + Same computation as forward(), but with explicit intermediate tensors. 
+ """ + batch_size = x.shape[0] // self.num_groups + + # Reshape + x_3d = x.reshape(batch_size, self.num_groups, self.input_dim) + + # 3 BMMs + 1 fused activation + gate = strided_bmm(x_3d, self.gate_weight) + up = strided_bmm(x_3d, self.up_weight) + hidden = silu_mul(gate, up) + output = strided_bmm(hidden, self.down_weight) + + return output.reshape(-1, self.output_dim) + + +# ============================================================================= +# [3] nn.Module: ReferenceGroupedMLP +# ============================================================================= + +class ReferenceGroupedMLP(nn.Module): + """ + Reference implementation using loop over groups with nn.Linear. + + This is the baseline for correctness verification. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.gate_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + self.up_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + self.down_proj = nn.ModuleList([ + nn.Linear(hidden_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + + self._init_weights() + + def _init_weights(self): + for module_list in [self.gate_proj, self.up_proj, self.down_proj]: + for layer in module_list: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward with loop over groups.""" + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + hidden_i = F.silu(gate_i) * up_i + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + +# ============================================================================= +# [4] Weight Copy Utilities +# ============================================================================= + +def copy_weights_ref_to_customop( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP_CustomOp, +): + """Copy weights from Reference model to CustomOp model.""" + with torch.no_grad(): + for i in range(ref_model.num_groups): + # nn.Linear weight is (out, in), we need (in, out) + customop_model.gate_weight[i].copy_(ref_model.gate_proj[i].weight.T) + customop_model.up_weight[i].copy_(ref_model.up_proj[i].weight.T) + customop_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +# ============================================================================= +# [5] Correctness Verification +# ============================================================================= + +def check_forward_correctness( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP_CustomOp, + batch_size: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[bool, float]: + """Check forward correctness.""" + num_groups = ref_model.num_groups + input_dim = ref_model.input_dim + + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + + with torch.no_grad(): + ref_out = ref_model(x) + 
customop_out = customop_model(x) + + max_diff = (ref_out - customop_out).abs().max().item() + # bf16 has limited precision, use looser tolerance + atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + passed = torch.allclose(ref_out, customop_out, atol=atol, rtol=rtol) + + return passed, max_diff + + +def check_backward_correctness( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP_CustomOp, + batch_size: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[bool, float]: + """Check backward correctness (input gradient).""" + num_groups = ref_model.num_groups + input_dim = ref_model.input_dim + + x_ref = torch.randn( + batch_size * num_groups, input_dim, + device="cuda", dtype=dtype, requires_grad=True + ) + x_customop = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + customop_out = customop_model(x_customop) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + customop_out.backward(grad_output) + + max_diff = (x_ref.grad - x_customop.grad).abs().max().item() + # bf16 has ~3 significant digits precision, multiple BMM ops accumulate error + # Use looser tolerance: 5e-2 for bf16 (vs 1e-2 for fp32) + atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + passed = torch.allclose(x_ref.grad, x_customop.grad, atol=atol, rtol=rtol) + + return passed, max_diff + + +def check_torch_compile(customop_model: GroupedMLP_CustomOp, batch_size: int) -> bool: + """Check torch.compile compatibility.""" + num_groups = customop_model.num_groups + input_dim = customop_model.input_dim + + @torch.compile(fullgraph=True) + def compiled_forward(model, x): + return model(x) + + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=torch.bfloat16) + + try: + out = compiled_forward(customop_model, x) + return out.shape == (batch_size * num_groups, customop_model.output_dim) + except Exception as e: + print(f"torch.compile failed: {e}") + return False + + +def run_opcheck(customop_model: GroupedMLP_CustomOp): + """Run opcheck on individual custom ops.""" + print("\nRunning opcheck on custom ops...") + + # Test silu_mul + examples_silu = [ + [torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)], + [torch.randn(16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True), + torch.randn(16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)], + ] + for i, ex in enumerate(examples_silu): + try: + torch.library.opcheck(silu_mul, ex) + print(f" silu_mul example {i+1}: PASSED") + except Exception as e: + print(f" silu_mul example {i+1}: FAILED - {e}") + + # Test strided_bmm + examples_bmm = [ + [torch.randn(32, 8, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16)], + [torch.randn(16, 12, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True), + torch.randn(12, 256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)], + ] + for i, ex in enumerate(examples_bmm): + try: + torch.library.opcheck(strided_bmm, ex) + print(f" strided_bmm example {i+1}: PASSED") + except Exception as e: + print(f" strided_bmm example {i+1}: FAILED - {e}") + + +# ============================================================================= +# [6] Benchmark Functions +# ============================================================================= + +def warmup_gpu(): + """Warmup GPU.""" + x = torch.randn(1000, 
1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +def benchmark_forward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward pass.""" + for i in range(num_warmup): + _ = model(x_list[i % len(x_list)]) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for i in range(num_iterations): + _ = model(x_list[i % len(x_list)]) + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / num_iterations + + +def benchmark_forward_backward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward + backward pass.""" + output_dim = model.output_dim + grad_outputs = [ + torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + x_with_grad = [xi.requires_grad_(True) for xi in x_list] + params = list(model.parameters()) + + for i in range(num_warmup): + xi = x_with_grad[i % len(x_list)] + out = model(xi) + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for i in range(num_iterations): + xi = x_with_grad[i % len(x_list)] + out = model(xi) + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / num_iterations + + +def benchmark_silu_mul_bandwidth( + shape: Tuple[int, ...], + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, +) -> Tuple[float, float, float, float]: + """ + Benchmark silu_mul kernel bandwidth. 
+ + Returns: (fwd_time_ms, fwd_bw_gb_s, bwd_time_ms, bwd_bw_gb_s) + """ + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + grad_output = torch.randn(shape, device="cuda", dtype=dtype) + + numel = gate.numel() + bytes_per_elem = gate.element_size() + fwd_bytes = 3 * numel * bytes_per_elem # read gate, up; write output + bwd_bytes = 5 * numel * bytes_per_elem # read grad_out, gate, up; write grad_gate, grad_up + + # Warmup + for _ in range(10): + _ = silu_mul(gate, up) + _ = silu_mul_backward(grad_output, gate, up) + torch.cuda.synchronize() + + # Forward benchmark + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(num_iterations): + _ = silu_mul(gate, up) + end.record() + torch.cuda.synchronize() + fwd_ms = start.elapsed_time(end) / num_iterations + fwd_bw = fwd_bytes / (fwd_ms / 1000) / 1e9 + + # Backward benchmark + start.record() + for _ in range(num_iterations): + _ = silu_mul_backward(grad_output, gate, up) + end.record() + torch.cuda.synchronize() + bwd_ms = start.elapsed_time(end) / num_iterations + bwd_bw = bwd_bytes / (bwd_ms / 1000) / 1e9 + + return fwd_ms, fwd_bw, bwd_ms, bwd_bw + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 70) + print("Grouped MLP as PyTorch Custom Ops") + print("=" * 70) + + torch.cuda.init() + print(f"\nDevice: {torch.cuda.get_device_name(0)}") + + # Configuration + batch_size = 2560 + num_groups = 12 + input_dim = 1024 + hidden_dim = 3072 + output_dim = 1024 + dtype = torch.bfloat16 + num_iterations = 100 + + print(f""" +Config: + Batch size: {batch_size} + Num groups: {num_groups} + Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} + Dtype: {dtype} +""") + + print("Warming up GPU...") + warmup_gpu() + + # Create models + print("Creating models...") + ref_model = ReferenceGroupedMLP( + input_dim, hidden_dim, output_dim, num_groups, dtype=dtype + ).cuda() + + customop_model = GroupedMLP_CustomOp( + input_dim, hidden_dim, output_dim, num_groups, dtype=dtype + ).cuda() + + # Copy weights + copy_weights_ref_to_customop(ref_model, customop_model) + + # ========================= + # Correctness Verification + # ========================= + print("\n" + "-" * 70) + print("Correctness Verification") + print("-" * 70) + + fwd_ok, fwd_diff = check_forward_correctness(ref_model, customop_model, batch_size, dtype) + print(f"\nForward: {'✓ PASS' if fwd_ok else '✗ FAIL'} (max_diff: {fwd_diff:.2e})") + + bwd_ok, bwd_diff = check_backward_correctness(ref_model, customop_model, batch_size, dtype) + print(f"Backward: {'✓ PASS' if bwd_ok else '✗ FAIL'} (max_diff: {bwd_diff:.2e})") + + compile_ok = check_torch_compile(customop_model, batch_size) + print(f"torch.compile(fullgraph=True): {'✓ PASS' if compile_ok else '✗ FAIL'}") + + # ========================= + # opcheck + # ========================= + print("\n" + "-" * 70) + print("Op Registration Check") + print("-" * 70) + run_opcheck(customop_model) + + # ========================= + # Performance Benchmark + # ========================= + print("\n" + "-" * 70) + print("Performance Benchmark") + print("-" * 70) + + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Forward + print("\n>>> Forward Pass <<<") + ref_fwd = benchmark_forward(ref_model, x_list, 
num_iterations) + customop_fwd = benchmark_forward(customop_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") + print(f"{'CustomOp (BMM)':<30} {customop_fwd:<12.4f} {ref_fwd/customop_fwd:<10.2f}x") + + # Forward + Backward + print("\n>>> Forward + Backward <<<") + ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations) + customop_fwdbwd = benchmark_forward_backward(customop_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") + print(f"{'CustomOp (BMM)':<30} {customop_fwdbwd:<12.4f} {ref_fwdbwd/customop_fwdbwd:<10.2f}x") + + # ========================= + # SiLU*Up Kernel Bandwidth + # ========================= + print("\n" + "-" * 70) + print("Fused SiLU*Up Kernel Bandwidth") + print("-" * 70) + + shape = (batch_size, num_groups, hidden_dim) + fwd_ms, fwd_bw, bwd_ms, bwd_bw = benchmark_silu_mul_bandwidth(shape, dtype) + + print(f"\nShape: {shape}") + print(f"\n{'Kernel':<15} {'Time (ms)':<12} {'BW (GB/s)':<12}") + print("-" * 40) + print(f"{'Forward':<15} {fwd_ms:<12.4f} {fwd_bw:<12.1f}") + print(f"{'Backward':<15} {bwd_ms:<12.4f} {bwd_bw:<12.1f}") + + # ========================= + # Summary + # ========================= + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + print(f""" + Forward Speedup: {ref_fwd/customop_fwd:.2f}x + Fwd+Bwd Speedup: {ref_fwdbwd/customop_fwdbwd:.2f}x + SiLU*Up BW: {fwd_bw:.1f} GB/s (fwd), {bwd_bw:.1f} GB/s (bwd) +""") + print("=" * 70) + print("Done!") + + +if __name__ == "__main__": + main() + From ba68e30058335fd95f79e02269bc173dc996f1ae Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 20 Jan 2026 10:09:47 +0000 Subject: [PATCH 7/7] Pre commit --- examples/commons/ops/GroupedMLP_example.py | 391 +++++++++++------- examples/commons/ops/grouped_mlp_customop.py | 397 +++++++++++-------- 2 files changed, 489 insertions(+), 299 deletions(-) diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py index edfc1c108..0736c2df8 100644 --- a/examples/commons/ops/GroupedMLP_example.py +++ b/examples/commons/ops/GroupedMLP_example.py @@ -20,17 +20,21 @@ ================================================================================ """ import sys -sys.path.insert(0, '/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu') + +sys.path.insert( + 0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu" +) import argparse from typing import Callable, List, Optional, Tuple, Union import torch +import torch.cuda.nvtx as nvtx import torch.nn as nn import torch.nn.functional as F -import torch.cuda.nvtx as nvtx import triton import triton.language as tl + try: # @manual=//triton:triton from triton.language.extra.libdevice import fast_dividef @@ -45,6 +49,7 @@ from ops.triton_ops.common import triton_autotune + def silu_configs(): configs = [] for x_block_size in [256, 512, 1024, 2048]: @@ -54,13 +59,11 @@ def silu_configs(): return configs - - - # ============================================================================= # Fused SiLU * Up (SwiGLU pattern): output = silu(gate) * up # ============================================================================= + @triton_autotune(silu_configs(), key=["x_size"]) @triton.jit def _silu_mul_forward( @@ -98,7 +101,7 @@ def _silu_mul_backward( ): """ Fused backward 
for output = silu(gate) * up - + grad_gate = grad_output * up * d(silu)/d(gate) = grad_output * up * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate))) grad_up = grad_output * silu(gate) @@ -107,29 +110,43 @@ def _silu_mul_backward( mask = x_offset + tl.arange(0, x_block_size) < x_size cols = tl.arange(0, x_block_size) - grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to( + tl.float32 + ) gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) sigma = tl.sigmoid(gate) silu_gate = gate * sigma - + # d(silu)/d(gate) = sigma + gate * sigma * (1 - sigma) dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) - + grad_gate = grad_output * up * dsilu_dgate grad_up = grad_output * silu_gate - tl.store(grad_gate_ptr + x_offset + cols, grad_gate.to(grad_gate_ptr.dtype.element_ty), mask=mask) - tl.store(grad_up_ptr + x_offset + cols, grad_up.to(grad_up_ptr.dtype.element_ty), mask=mask) + tl.store( + grad_gate_ptr + x_offset + cols, + grad_gate.to(grad_gate_ptr.dtype.element_ty), + mask=mask, + ) + tl.store( + grad_up_ptr + x_offset + cols, + grad_up.to(grad_up_ptr.dtype.element_ty), + mask=mask, + ) def triton_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: """Forward: output = silu(gate) * up""" assert gate.shape == up.shape, f"Shape mismatch: gate {gate.shape} vs up {up.shape}" x_size = gate.numel() - gate_1d = gate.view(-1).contiguous() - up_1d = up.view(-1).contiguous() + + with torch.cuda.nvtx.range("gate_contiguous"): + gate_1d = gate.view(-1).contiguous() + + with torch.cuda.nvtx.range("up_contiguous"): + up_1d = up.view(-1).contiguous() output = torch.empty_like(gate_1d) def grid(meta): @@ -172,7 +189,7 @@ def grid(meta): class TritonSiluMul(torch.autograd.Function): """Autograd function for fused silu(gate) * up""" - + @staticmethod def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: output = triton_silu_mul_fwd(gate, up) @@ -189,13 +206,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor def triton_silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: """ Fused SiLU multiplication (SwiGLU pattern). - + Computes: output = silu(gate) * up - + Args: gate: Input tensor that goes through SiLU activation up: Input tensor that multiplies with activated gate - + Returns: output: silu(gate) * up """ @@ -222,7 +239,7 @@ def benchmark_fused_silu_mul_bandwidth( ) -> Tuple[float, float, float]: """ Benchmark fused SiLU * up kernel and calculate memory bandwidth. 
- + Returns: (time_ms, bandwidth_gb_s, achieved_percent) """ @@ -230,44 +247,44 @@ def benchmark_fused_silu_mul_bandwidth( shape = (batch_size, num_groups, hidden_dim) gate = torch.randn(shape, device="cuda", dtype=dtype) up = torch.randn(shape, device="cuda", dtype=dtype) - + # Calculate data volume numel = gate.numel() bytes_per_element = gate.element_size() # 2 for bf16, 4 for fp32 # Read: gate + up, Write: output total_bytes = 3 * numel * bytes_per_element - + # Warmup for _ in range(num_warmup): _ = triton_silu_mul(gate, up) torch.cuda.synchronize() - + # Benchmark start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + start_event.record() for _ in range(num_iterations): _ = triton_silu_mul(gate, up) end_event.record() torch.cuda.synchronize() - + time_ms = start_event.elapsed_time(end_event) / num_iterations time_s = time_ms / 1000.0 - + # Calculate bandwidth bandwidth_gb_s = total_bytes / time_s / 1e9 - + # Get theoretical peak bandwidth (for reference) # H100 SXM: ~3.35 TB/s, A100 SXM: ~2.0 TB/s, A100 PCIe: ~1.9 TB/s # You can query this from torch.cuda.get_device_properties - props = torch.cuda.get_device_properties(0) + torch.cuda.get_device_properties(0) # Memory bandwidth = memory_clock_rate (kHz) * bus_width (bits) * 2 (DDR) / 8 (bits to bytes) # Note: This is approximate; actual peak may differ theoretical_peak_gb_s = 2000.0 # Default estimate, adjust based on your GPU - + achieved_percent = bandwidth_gb_s / theoretical_peak_gb_s * 100 - + return time_ms, bandwidth_gb_s, achieved_percent @@ -281,11 +298,11 @@ def benchmark_fused_silu_mul_backward_bandwidth( ) -> Tuple[float, float]: """ Benchmark fused SiLU * up backward kernel and calculate memory bandwidth. - + Backward reads: grad_output, gate, up (3 tensors) Backward writes: grad_gate, grad_up (2 tensors) Total: 5 * numel * bytes_per_element - + Returns: (time_ms, bandwidth_gb_s) """ @@ -293,31 +310,31 @@ def benchmark_fused_silu_mul_backward_bandwidth( gate = torch.randn(shape, device="cuda", dtype=dtype) up = torch.randn(shape, device="cuda", dtype=dtype) grad_output = torch.randn(shape, device="cuda", dtype=dtype) - + numel = gate.numel() bytes_per_element = gate.element_size() # Read: grad_output + gate + up, Write: grad_gate + grad_up total_bytes = 5 * numel * bytes_per_element - + # Warmup for _ in range(num_warmup): _ = triton_silu_mul_bwd(grad_output, gate, up) torch.cuda.synchronize() - + # Benchmark start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + start_event.record() for _ in range(num_iterations): _ = triton_silu_mul_bwd(grad_output, gate, up) end_event.record() torch.cuda.synchronize() - + time_ms = start_event.elapsed_time(end_event) / num_iterations time_s = time_ms / 1000.0 bandwidth_gb_s = total_bytes / time_s / 1e9 - + return time_ms, bandwidth_gb_s @@ -343,6 +360,7 @@ def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: # Reference Implementation # ============================================================================= + class ReferenceGroupedMLP(nn.Module): """Reference implementation using loop over groups.""" @@ -366,32 +384,50 @@ def __init__( self.act_fn = get_activation_fn(activation) if use_gating: - self.gate_proj = nn.ModuleList([ - nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) - self.up_proj = nn.ModuleList([ - nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) - for _ in 
range(num_groups) - ]) + self.gate_proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + self.up_proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) else: - self.proj = nn.ModuleList([ - nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) + self.proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) - self.down_proj = nn.ModuleList([ - nn.Linear(hidden_dim, output_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) + self.down_proj = nn.ModuleList( + [ + nn.Linear( + hidden_dim, output_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) self._init_weights() def _init_weights(self): - for module_list in [getattr(self, 'gate_proj', []), - getattr(self, 'up_proj', []), - getattr(self, 'proj', []), - self.down_proj]: + for module_list in [ + getattr(self, "gate_proj", []), + getattr(self, "up_proj", []), + getattr(self, "proj", []), + self.down_proj, + ]: for layer in module_list: nn.init.xavier_normal_(layer.weight, gain=1.0) @@ -399,10 +435,10 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: if enable_nvtx: with nvtx.range("Ref_reshape"): x = x.reshape(-1, self.num_groups, self.input_dim) - + with nvtx.range("Ref_split"): x_split = torch.split(x, 1, dim=1) - + with nvtx.range("Ref_loop_gemm"): out_list = [] for i in range(self.num_groups): @@ -420,13 +456,13 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: hidden_i = self.act_fn(hidden_i) out_i = self.down_proj[i](hidden_i) out_list.append(out_i) - + with nvtx.range("Ref_stack"): output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) else: x = x.reshape(-1, self.num_groups, self.input_dim) x_split = torch.split(x, 1, dim=1) - + out_list = [] for i in range(self.num_groups): x_i = x_split[i].squeeze(1) @@ -443,16 +479,16 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: hidden_i = self.act_fn(hidden_i) out_i = self.down_proj[i](hidden_i) out_list.append(out_i) - + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) - + return output def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: """Forward pass up to hidden (excluding down projection).""" x = x.reshape(-1, self.num_groups, self.input_dim) x_split = torch.split(x, 1, dim=1) - + hidden_list = [] for i in range(self.num_groups): x_i = x_split[i].squeeze(1) @@ -468,7 +504,7 @@ def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: if self.act_fn is not None: hidden_i = self.act_fn(hidden_i) hidden_list.append(hidden_i) - + return torch.stack(hidden_list, dim=1).reshape(-1, self.hidden_dim) @@ -476,34 +512,38 @@ def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: # Strided BMM Function # ============================================================================= + class StridedBmmFunction(torch.autograd.Function): """Custom autograd function for BMM with strided output.""" @staticmethod def forward(ctx, x, weight, batch_size, num_groups, output_dim): ctx.save_for_backward(x, weight) - - output = torch.empty(batch_size, num_groups, output_dim, - device=x.device, dtype=x.dtype) + + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, 
dtype=x.dtype + ) torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) - + return output @staticmethod def backward(ctx, grad_output): x, weight = ctx.saved_tensors grad_x = grad_weight = None - + grad_output_t = grad_output.permute(1, 0, 2) - + if ctx.needs_input_grad[0]: grad_x = torch.empty_like(x) - torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) - + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) + if ctx.needs_input_grad[1]: x_t = x.permute(1, 0, 2) grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) - + return grad_x, grad_weight, None, None, None @@ -511,6 +551,7 @@ def backward(ctx, grad_output): # Plan A: 3 Independent BMMs # ============================================================================= + class GroupedMLP_PlanA(nn.Module): """Plan A: 3 independent strided BMMs (gate, up, down).""" @@ -535,14 +576,20 @@ def __init__( if use_gating: self.gate_weight = nn.Parameter( - torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) ) self.up_weight = nn.Parameter( - torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) ) else: self.proj_weight = nn.Parameter( - torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) ) self.down_weight = nn.Parameter( @@ -562,46 +609,59 @@ def _init_weights(self): def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: batch_size = x.shape[0] // self.num_groups - + if enable_nvtx: with nvtx.range("PlanA_reshape"): x = x.reshape(batch_size, self.num_groups, self.input_dim) - + if self.use_gating: with nvtx.range("PlanA_gate_bmm"): gate = StridedBmmFunction.apply( - x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + x, + self.gate_weight, + batch_size, + self.num_groups, + self.hidden_dim, ) - + with nvtx.range("PlanA_up_bmm"): up = StridedBmmFunction.apply( x, self.up_weight, batch_size, self.num_groups, self.hidden_dim ) - + with nvtx.range("PlanA_activation"): if self.act_fn is not None: - hidden = self.act_fn(gate) * up + # hidden = self.act_fn(gate) * up + hidden = triton_silu_mul(gate, up) else: hidden = gate * up else: with nvtx.range("PlanA_proj_bmm"): hidden = StridedBmmFunction.apply( - x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + x, + self.proj_weight, + batch_size, + self.num_groups, + self.hidden_dim, ) with nvtx.range("PlanA_activation"): if self.act_fn is not None: hidden = self.act_fn(hidden) - + with nvtx.range("PlanA_down_bmm"): output = StridedBmmFunction.apply( - hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + hidden, + self.down_weight, + batch_size, + self.num_groups, + self.output_dim, ) - + with nvtx.range("PlanA_view"): return output.view(-1, self.output_dim) else: x = x.reshape(batch_size, self.num_groups, self.input_dim) - + if self.use_gating: gate = StridedBmmFunction.apply( x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim @@ -620,18 +680,18 @@ def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: ) if self.act_fn is not None: hidden = self.act_fn(hidden) - + output = StridedBmmFunction.apply( hidden, self.down_weight, batch_size, self.num_groups, self.output_dim ) - + return output.view(-1, 
self.output_dim) def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: """Forward pass up to hidden (excluding down projection).""" batch_size = x.shape[0] // self.num_groups x = x.reshape(batch_size, self.num_groups, self.input_dim) - + if self.use_gating: gate = StridedBmmFunction.apply( x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim @@ -649,7 +709,7 @@ def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: ) if self.act_fn is not None: hidden = self.act_fn(hidden) - + return hidden.view(-1, self.hidden_dim) @@ -657,6 +717,7 @@ def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: # Weight Copy Utilities # ============================================================================= + def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanA): with torch.no_grad(): num_groups = ref_model.num_groups @@ -675,6 +736,7 @@ def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP # Correctness Check # ============================================================================= + def check_correctness( ref_model: nn.Module, opt_model: nn.Module, @@ -691,8 +753,11 @@ def check_correctness( fwd_diff = (ref_out - opt_out).abs().max().item() x_ref = torch.randn( - batch_size * num_groups, input_dim, - device="cuda", dtype=dtype, requires_grad=True + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, ) x_opt = x_ref.detach().clone().requires_grad_(True) @@ -712,6 +777,7 @@ def check_correctness( # Benchmark Functions (same as benchmark_batched_gemm.py) # ============================================================================= + def benchmark_forward( model: nn.Module, x_list: List[torch.Tensor], @@ -721,16 +787,16 @@ def benchmark_forward( ) -> float: """Benchmark forward pass using CUDA events for accurate GPU timing.""" model_name = model.__class__.__name__ - + # Warmup for i in range(num_warmup): _ = model(x_list[i % len(x_list)], enable_nvtx=False) torch.cuda.synchronize() - + # Create CUDA events start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + # Benchmark start_event.record() if enable_nvtx: @@ -742,7 +808,7 @@ def benchmark_forward( _ = model(x_list[i % len(x_list)], enable_nvtx=False) end_event.record() torch.cuda.synchronize() - + return start_event.elapsed_time(end_event) / num_iterations @@ -757,18 +823,18 @@ def benchmark_forward_to_hidden( for i in range(num_warmup): _ = model.forward_to_hidden(x_list[i % len(x_list)]) torch.cuda.synchronize() - + # Create CUDA events start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + # Benchmark start_event.record() for i in range(num_iterations): _ = model.forward_to_hidden(x_list[i % len(x_list)]) end_event.record() torch.cuda.synchronize() - + return start_event.elapsed_time(end_event) / num_iterations @@ -782,15 +848,15 @@ def benchmark_forward_backward( """Benchmark forward + backward pass using CUDA events.""" model_name = model.__class__.__name__ output_dim = model.output_dim - + grad_outputs = [ torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) for xi in x_list ] - + x_with_grad = [xi.requires_grad_(True) for xi in x_list] params = list(model.parameters()) - + # Warmup for i in range(num_warmup): xi = x_with_grad[i % len(x_list)] @@ -801,11 +867,11 @@ def benchmark_forward_backward( grad_outputs=grad_outputs[i % len(x_list)], ) torch.cuda.synchronize() - + # Create CUDA events 
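+    # CUDA events are timestamped on the GPU stream itself, so the interval
+    # reflects device execution time; elapsed_time() returns milliseconds and
+    # is only valid after the events have completed (hence the synchronize()).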
start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + # Benchmark start_event.record() if enable_nvtx: @@ -843,7 +909,7 @@ def benchmark_forward_backward( ) end_event.record() torch.cuda.synchronize() - + return start_event.elapsed_time(end_event) / num_iterations @@ -851,6 +917,7 @@ def benchmark_forward_backward( # Main # ============================================================================= + def main(): parser = argparse.ArgumentParser( description="Grouped MLP Benchmark: Reference vs Plan A" @@ -860,14 +927,22 @@ def main(): parser.add_argument("--input-dim", type=int, default=1024) parser.add_argument("--hidden-dim", type=int, default=3072) parser.add_argument("--output-dim", type=int, default=1024) - parser.add_argument("--activation", type=str, default="silu", - choices=["silu", "gelu", "relu", "tanh", "none"]) + parser.add_argument( + "--activation", + type=str, + default="silu", + choices=["silu", "gelu", "relu", "tanh", "none"], + ) parser.add_argument("--no-gating", action="store_true") parser.add_argument("--iterations", type=int, default=100) - parser.add_argument("--enable-nvtx", action="store_true", - help="Enable NVTX markers (use with nsys profile)") - parser.add_argument("--compile", action="store_true", - help="Use torch.compile() to optimize models") + parser.add_argument( + "--enable-nvtx", + action="store_true", + help="Enable NVTX markers (use with nsys profile)", + ) + parser.add_argument( + "--compile", action="store_true", help="Use torch.compile() to optimize models" + ) args = parser.parse_args() torch.cuda.init() @@ -885,12 +960,13 @@ def main(): print("=" * 80) print("Grouped MLP Benchmark: Reference vs Plan A") print("=" * 80) - + if args.enable_nvtx: print("\n*** NVTX PROFILING MODE ***") print("Run with: nsys profile -o --trace=cuda,nvtx python ...") - - print(f""" + + print( + f""" Config: Batch size: {batch_size} Num groups: {num_groups} @@ -900,7 +976,8 @@ def main(): Dtype: {dtype} Device: {torch.cuda.get_device_name(0)} Iterations: {num_iterations} -""") +""" + ) print("Warming up GPU...") warmup_gpu() @@ -908,13 +985,23 @@ def main(): # Create models print("Creating models...") ref_model = ReferenceGroupedMLP( - input_dim, hidden_dim, output_dim, num_groups, - use_gating=use_gating, activation=activation, dtype=dtype + input_dim, + hidden_dim, + output_dim, + num_groups, + use_gating=use_gating, + activation=activation, + dtype=dtype, ).cuda() plan_a_model = GroupedMLP_PlanA( - input_dim, hidden_dim, output_dim, num_groups, - use_gating=use_gating, activation=activation, dtype=dtype + input_dim, + hidden_dim, + output_dim, + num_groups, + use_gating=use_gating, + activation=activation, + dtype=dtype, ).cuda() copy_weights_to_plan_a(ref_model, plan_a_model) @@ -930,8 +1017,10 @@ def main(): print("-" * 60) print("Correctness Check") print("-" * 60) - - fwd_a, bwd_a = check_correctness(ref_model, plan_a_model, batch_size, num_groups, input_dim, dtype) + + fwd_a, bwd_a = check_correctness( + ref_model, plan_a_model, batch_size, num_groups, input_dim, dtype + ) print(f"Plan A - Forward diff: {fwd_a:.2e}, Backward diff: {bwd_a:.2e}") # Prepare test data @@ -948,12 +1037,16 @@ def main(): # Forward - always benchmark without NVTX print("\n>>> Forward Pass <<<") ref_fwd = benchmark_forward(ref_model, x_list, num_iterations, enable_nvtx=False) - plan_a_fwd = benchmark_forward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwd = benchmark_forward( + plan_a_model, x_list, 
num_iterations, enable_nvtx=False + ) print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") - print(f"{'Plan A (batched BMM)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x" + ) # Forward to Hidden (excluding down projection) print("\n>>> Forward to Hidden (excluding down_proj) <<<") @@ -963,17 +1056,25 @@ def main(): print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_hidden:<12.4f} {'1.00x':<10}") - print(f"{'Plan A (batched BMM)':<30} {plan_a_hidden:<12.4f} {ref_hidden/plan_a_hidden:<10.2f}x") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_hidden:<12.4f} {ref_hidden/plan_a_hidden:<10.2f}x" + ) # Forward + Backward - always benchmark without NVTX print("\n>>> Forward + Backward <<<") - ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations, enable_nvtx=False) - plan_a_fwdbwd = benchmark_forward_backward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + ref_fwdbwd = benchmark_forward_backward( + ref_model, x_list, num_iterations, enable_nvtx=False + ) + plan_a_fwdbwd = benchmark_forward_backward( + plan_a_model, x_list, num_iterations, enable_nvtx=False + ) print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") - print(f"{'Plan A (batched BMM)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x" + ) # NVTX profiling run (separate from benchmark) if args.enable_nvtx: @@ -981,14 +1082,18 @@ def main(): print("NVTX Profiling Run (for nsys analysis only)") print("-" * 60) torch.cuda.profiler.start() - + # Run a few iterations with NVTX for profiling nvtx_iterations = min(10, num_iterations) _ = benchmark_forward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) _ = benchmark_forward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward_backward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) - _ = benchmark_forward_backward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) - + _ = benchmark_forward_backward( + ref_model, x_list, nvtx_iterations, enable_nvtx=True + ) + _ = benchmark_forward_backward( + plan_a_model, x_list, nvtx_iterations, enable_nvtx=True + ) + torch.cuda.profiler.stop() print("NVTX profiling complete.") @@ -996,38 +1101,41 @@ def main(): print("\n" + "-" * 60) print("Fused SiLU*Up Kernel Bandwidth Analysis") print("-" * 60) - + fwd_time, fwd_bw, fwd_pct = benchmark_fused_silu_mul_bandwidth( batch_size, num_groups, hidden_dim, dtype, num_iterations ) bwd_time, bwd_bw = benchmark_fused_silu_mul_backward_bandwidth( batch_size, num_groups, hidden_dim, dtype, num_iterations ) - + # Calculate data sizes for reference numel = batch_size * num_groups * hidden_dim bytes_per_elem = 2 if dtype == torch.bfloat16 else 4 fwd_data_mb = 3 * numel * bytes_per_elem / 1e6 bwd_data_mb = 5 * numel * bytes_per_elem / 1e6 - + print(f"\nTensor shape: ({batch_size}, {num_groups}, {hidden_dim})") print(f"Elements: {numel:,} ({numel * bytes_per_elem / 1e6:.2f} MB per tensor)") - + print(f"\n{'Kernel':<20} {'Time (ms)':<12} {'Data (MB)':<12} {'BW (GB/s)':<12}") print("-" * 56) - print(f"{'Forward (silu*up)':<20} {fwd_time:<12.4f} {fwd_data_mb:<12.2f} {fwd_bw:<12.1f}") + print( + 
f"{'Forward (silu*up)':<20} {fwd_time:<12.4f} {fwd_data_mb:<12.2f} {fwd_bw:<12.1f}" + ) print(f"{'Backward':<20} {bwd_time:<12.4f} {bwd_data_mb:<12.2f} {bwd_bw:<12.1f}") - + print(f"\nNote: Peak memory bandwidth varies by GPU:") print(f" - H100 SXM: ~3350 GB/s") - print(f" - A100 SXM: ~2039 GB/s") + print(f" - A100 SXM: ~2039 GB/s") print(f" - A100 PCIe: ~1935 GB/s") # Summary print("\n" + "=" * 80) print("Summary") print("=" * 80) - print(f""" + print( + f""" Implementation Details: Reference: Loop over {num_groups} groups, uses nn.Linear (C++ autograd) Plan A: Batched BMM with custom StridedBmmFunction @@ -1044,7 +1152,8 @@ def main(): Fused SiLU*Up Kernel: Forward: {fwd_bw:.1f} GB/s Backward: {bwd_bw:.1f} GB/s -""") +""" + ) print("=" * 80) print("Done!") diff --git a/examples/commons/ops/grouped_mlp_customop.py b/examples/commons/ops/grouped_mlp_customop.py index df126cdc6..1f0fdd3e1 100644 --- a/examples/commons/ops/grouped_mlp_customop.py +++ b/examples/commons/ops/grouped_mlp_customop.py @@ -12,9 +12,12 @@ """ import sys -sys.path.insert(0, '/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu') -from typing import Callable, List, Optional, Tuple, Union +sys.path.insert( + 0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu" +) + +from typing import List, Tuple, Union import torch import torch.nn as nn @@ -32,11 +35,11 @@ from ops.triton_ops.common import triton_autotune - # ============================================================================= # [1] Triton Kernel Configurations # ============================================================================= + def silu_configs(): configs = [] for x_block_size in [256, 512, 1024, 2048]: @@ -50,6 +53,7 @@ def silu_configs(): # [1] Triton Kernels - SiLU * Up (SwiGLU pattern) # ============================================================================= + @triton_autotune(silu_configs(), key=["x_size"]) @triton.jit def _silu_mul_forward_kernel( @@ -89,25 +93,36 @@ def _silu_mul_backward_kernel( mask = x_offset + tl.arange(0, x_block_size) < x_size cols = tl.arange(0, x_block_size) - grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to( + tl.float32 + ) gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) sigma = tl.sigmoid(gate) silu_gate = gate * sigma dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) - + grad_gate = grad_output * up * dsilu_dgate grad_up = grad_output * silu_gate - tl.store(grad_gate_ptr + x_offset + cols, grad_gate.to(grad_gate_ptr.dtype.element_ty), mask=mask) - tl.store(grad_up_ptr + x_offset + cols, grad_up.to(grad_up_ptr.dtype.element_ty), mask=mask) + tl.store( + grad_gate_ptr + x_offset + cols, + grad_gate.to(grad_gate_ptr.dtype.element_ty), + mask=mask, + ) + tl.store( + grad_up_ptr + x_offset + cols, + grad_up.to(grad_up_ptr.dtype.element_ty), + mask=mask, + ) # ============================================================================= # [1] Triton Kernel Launchers (Internal) # ============================================================================= + def _launch_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: """Internal: launch forward kernel""" x_size = gate.numel() @@ -137,22 +152,27 @@ def _launch_silu_mul_bwd( def grid(meta): return (triton.cdiv(x_size, meta["x_block_size"]),) - 
_silu_mul_backward_kernel[grid](grad_gate, grad_up, grad_output_1d, gate_1d, up_1d, x_size) + _silu_mul_backward_kernel[grid]( + grad_gate, grad_up, grad_output_1d, gate_1d, up_1d, x_size + ) return grad_gate.view(shape), grad_up.view(shape) # ============================================================================= -# [2] Custom Op: silu_mul +# [2] Custom Op: silu_mul # ============================================================================= + @torch.library.custom_op("grouped_mlp::silu_mul", mutates_args=(), device_types="cuda") def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: """ Fused SiLU multiplication: output = silu(gate) * up - + This is the SwiGLU activation pattern used in modern LLMs. """ - torch._check(gate.shape == up.shape, lambda: f"Shape mismatch: {gate.shape} vs {up.shape}") + torch._check( + gate.shape == up.shape, lambda: f"Shape mismatch: {gate.shape} vs {up.shape}" + ) return _launch_silu_mul_fwd(gate.contiguous(), up.contiguous()) @@ -162,7 +182,9 @@ def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: return torch.empty_like(gate) -@torch.library.custom_op("grouped_mlp::silu_mul_backward", mutates_args=(), device_types="cuda") +@torch.library.custom_op( + "grouped_mlp::silu_mul_backward", mutates_args=(), device_types="cuda" +) def silu_mul_backward( grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -173,7 +195,9 @@ def silu_mul_backward( @silu_mul_backward.register_fake -def _(grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def _( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(gate), torch.empty_like(up) @@ -194,24 +218,29 @@ def _silu_mul_setup_ctx(ctx, inputs, output): # [2] Custom Op: strided_bmm (Strided Batched Matrix Multiply) # ============================================================================= -@torch.library.custom_op("grouped_mlp::strided_bmm", mutates_args=(), device_types="cuda") + +@torch.library.custom_op( + "grouped_mlp::strided_bmm", mutates_args=(), device_types="cuda" +) def strided_bmm( - x: torch.Tensor, # (B, N, K) - weight: torch.Tensor, # (N, K, M) + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) ) -> torch.Tensor: """ Strided BMM: output[b, n, :] = x[b, n, :] @ weight[n, :, :] - + Input: x - (batch_size, num_groups, input_dim) Weight: weight - (num_groups, input_dim, output_dim) Output: - (batch_size, num_groups, output_dim) - + Internally uses permute + torch.bmm for efficiency. 
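+    Loop-based sketch of the same computation (reference only; the op itself
+    maps to a single torch.bmm call on permuted views):
+
+        for n in range(num_groups):
+            output[:, n, :] = x[:, n, :] @ weight[n]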
""" batch_size, num_groups, _ = x.shape output_dim = weight.shape[2] - - output = torch.empty(batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype) + + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) # x: (B, N, K) -> (N, B, K) # weight: (N, K, M) # out: (N, B, M) -> (B, N, M) @@ -226,33 +255,37 @@ def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: return x.new_empty(batch_size, num_groups, output_dim) -@torch.library.custom_op("grouped_mlp::strided_bmm_backward", mutates_args=(), device_types="cuda") +@torch.library.custom_op( + "grouped_mlp::strided_bmm_backward", mutates_args=(), device_types="cuda" +) def strided_bmm_backward( grad_output: torch.Tensor, # (B, N, M) - x: torch.Tensor, # (B, N, K) - weight: torch.Tensor, # (N, K, M) + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) ) -> Tuple[torch.Tensor, torch.Tensor]: """ Backward for strided_bmm. - + grad_x = grad_output @ weight.T => (B, N, K) grad_weight = x.T @ grad_output => (N, K, M) """ grad_output_t = grad_output.permute(1, 0, 2) # (N, B, M) - + # grad_x: (N, B, M) @ (N, M, K) -> (N, B, K) -> (B, N, K) grad_x = torch.empty_like(x) torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) - + # grad_weight: (N, K, B) @ (N, B, M) -> (N, K, M) x_t = x.permute(1, 0, 2) # (N, B, K) grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) - + return grad_x, grad_weight @strided_bmm_backward.register_fake -def _(grad_output: torch.Tensor, x: torch.Tensor, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def _( + grad_output: torch.Tensor, x: torch.Tensor, weight: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(x), torch.empty_like(weight) @@ -270,76 +303,77 @@ def _strided_bmm_setup_ctx(ctx, inputs, output): def grouped_mlp_gated_forward( - x: torch.Tensor, # (B*N, D_in) - gate_weight: torch.Tensor, # (N, D_in, D_hidden) - up_weight: torch.Tensor, # (N, D_in, D_hidden) - down_weight: torch.Tensor, # (N, D_hidden, D_out) + x: torch.Tensor, # (B*N, D_in) + gate_weight: torch.Tensor, # (N, D_in, D_hidden) + up_weight: torch.Tensor, # (N, D_in, D_hidden) + down_weight: torch.Tensor, # (N, D_hidden, D_out) num_groups: int, ) -> torch.Tensor: """ Grouped MLP forward with gating (SwiGLU pattern). - + This is a composition of custom ops (strided_bmm, silu_mul). Autograd is handled automatically through the registered backward of each op. 
- + For each group n: gate = x @ gate_weight[n] up = x @ up_weight[n] hidden = silu(gate) * up output = hidden @ down_weight[n] - + Args: x: Input tensor, shape (B*N, D_in) gate_weight: Gate projection weights, shape (N, D_in, D_hidden) up_weight: Up projection weights, shape (N, D_in, D_hidden) down_weight: Down projection weights, shape (N, D_hidden, D_out) num_groups: Number of groups (N) - + Returns: Output tensor, shape (B*N, D_out) """ batch_size = x.shape[0] // num_groups input_dim = gate_weight.shape[1] output_dim = down_weight.shape[2] - + # Reshape: (B*N, D_in) -> (B, N, D_in) x_3d = x.reshape(batch_size, num_groups, input_dim) - + # Gate BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) gate = strided_bmm(x_3d, gate_weight) - + # Up BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) up = strided_bmm(x_3d, up_weight) - + # Fused SiLU activation: hidden = silu(gate) * up hidden = silu_mul(gate, up) - + # Down BMM: (B, N, D_hidden) @ (N, D_hidden, D_out) -> (B, N, D_out) output = strided_bmm(hidden, down_weight) - + # Reshape: (B, N, D_out) -> (B*N, D_out) return output.reshape(-1, output_dim) # ============================================================================= -# [3] nn.Module: GroupedMLP_CustomOp (Plan A using Custom Ops) +# [3] nn.Module: GroupedMLP_CustomOp # ============================================================================= -class GroupedMLP_CustomOp(nn.Module): + +class GroupedMLP(nn.Module): """ - Grouped MLP using custom ops (Plan A: BMM-based). - + Grouped MLP using custom ops. + This implementation uses registered custom ops instead of autograd.Function, making it compatible with torch.compile and scenarios where autograd is not available. - + Forward computation: For each group n: gate = x @ gate_weight[n] - up = x @ up_weight[n] + up = x @ up_weight[n] hidden = silu(gate) * up output = hidden @ down_weight[n] """ - + def __init__( self, input_dim: int, @@ -354,7 +388,7 @@ def __init__( self.input_dim = input_dim self.hidden_dim = hidden_dim self.output_dim = output_dim - + # Weights: (N, D_in, D_hidden) or (N, D_hidden, D_out) self.gate_weight = nn.Parameter( torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) @@ -365,22 +399,22 @@ def __init__( self.down_weight = nn.Parameter( torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) ) - + self._init_weights() - + def _init_weights(self): for i in range(self.num_groups): nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) nn.init.xavier_normal_(self.up_weight[i], gain=1.0) nn.init.xavier_normal_(self.down_weight[i], gain=1.0) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass using custom ops. - + Args: x: Input tensor, shape (B*N, D_in) - + Returns: Output tensor, shape (B*N, D_out) """ @@ -391,38 +425,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.down_weight, self.num_groups, ) - + def forward_decomposed(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass with explicit steps (for debugging/profiling). - + Same computation as forward(), but with explicit intermediate tensors. 
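+        Intermediate shapes (B = batch per group, N = num_groups):
+            gate, up, hidden: (B, N, hidden_dim)
+            output:           (B, N, output_dim), reshaped to (B*N, output_dim)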
""" batch_size = x.shape[0] // self.num_groups - + # Reshape x_3d = x.reshape(batch_size, self.num_groups, self.input_dim) - + # 3 BMMs + 1 fused activation gate = strided_bmm(x_3d, self.gate_weight) up = strided_bmm(x_3d, self.up_weight) hidden = silu_mul(gate, up) output = strided_bmm(hidden, self.down_weight) - + return output.reshape(-1, self.output_dim) # ============================================================================= -# [3] nn.Module: ReferenceGroupedMLP +# [3] nn.Module: ReferenceGroupedMLP # ============================================================================= + class ReferenceGroupedMLP(nn.Module): """ Reference implementation using loop over groups with nn.Linear. - + This is the baseline for correctness verification. """ - + def __init__( self, input_dim: int, @@ -437,32 +472,40 @@ def __init__( self.input_dim = input_dim self.hidden_dim = hidden_dim self.output_dim = output_dim - - self.gate_proj = nn.ModuleList([ - nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) - self.up_proj = nn.ModuleList([ - nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) - self.down_proj = nn.ModuleList([ - nn.Linear(hidden_dim, output_dim, bias=False, device=device, dtype=dtype) - for _ in range(num_groups) - ]) - + + self.gate_proj = nn.ModuleList( + [ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + self.up_proj = nn.ModuleList( + [ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + self.down_proj = nn.ModuleList( + [ + nn.Linear( + hidden_dim, output_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + self._init_weights() - + def _init_weights(self): for module_list in [self.gate_proj, self.up_proj, self.down_proj]: for layer in module_list: nn.init.xavier_normal_(layer.weight, gain=1.0) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward with loop over groups.""" x = x.reshape(-1, self.num_groups, self.input_dim) x_split = torch.split(x, 1, dim=1) - + out_list = [] for i in range(self.num_groups): x_i = x_split[i].squeeze(1) @@ -471,7 +514,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_i = F.silu(gate_i) * up_i out_i = self.down_proj[i](hidden_i) out_list.append(out_i) - + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) @@ -479,9 +522,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [4] Weight Copy Utilities # ============================================================================= + def copy_weights_ref_to_customop( ref_model: ReferenceGroupedMLP, - customop_model: GroupedMLP_CustomOp, + customop_model: GroupedMLP, ): """Copy weights from Reference model to CustomOp model.""" with torch.no_grad(): @@ -496,75 +540,81 @@ def copy_weights_ref_to_customop( # [5] Correctness Verification # ============================================================================= + def check_forward_correctness( ref_model: ReferenceGroupedMLP, - customop_model: GroupedMLP_CustomOp, + customop_model: GroupedMLP, batch_size: int, dtype: torch.dtype = torch.bfloat16, ) -> Tuple[bool, float]: """Check forward correctness.""" num_groups = ref_model.num_groups input_dim = ref_model.input_dim - + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) - + with torch.no_grad(): ref_out = ref_model(x) customop_out = customop_model(x) - + 
max_diff = (ref_out - customop_out).abs().max().item() # bf16 has limited precision, use looser tolerance atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 passed = torch.allclose(ref_out, customop_out, atol=atol, rtol=rtol) - + return passed, max_diff def check_backward_correctness( ref_model: ReferenceGroupedMLP, - customop_model: GroupedMLP_CustomOp, + customop_model: GroupedMLP, batch_size: int, dtype: torch.dtype = torch.bfloat16, ) -> Tuple[bool, float]: """Check backward correctness (input gradient).""" num_groups = ref_model.num_groups input_dim = ref_model.input_dim - + x_ref = torch.randn( - batch_size * num_groups, input_dim, - device="cuda", dtype=dtype, requires_grad=True + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, ) x_customop = x_ref.detach().clone().requires_grad_(True) - + ref_out = ref_model(x_ref) customop_out = customop_model(x_customop) - + grad_output = torch.randn_like(ref_out) ref_out.backward(grad_output) customop_out.backward(grad_output) - + max_diff = (x_ref.grad - x_customop.grad).abs().max().item() # bf16 has ~3 significant digits precision, multiple BMM ops accumulate error # Use looser tolerance: 5e-2 for bf16 (vs 1e-2 for fp32) atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 passed = torch.allclose(x_ref.grad, x_customop.grad, atol=atol, rtol=rtol) - + return passed, max_diff -def check_torch_compile(customop_model: GroupedMLP_CustomOp, batch_size: int) -> bool: +def check_torch_compile(customop_model: GroupedMLP, batch_size: int) -> bool: """Check torch.compile compatibility.""" num_groups = customop_model.num_groups input_dim = customop_model.input_dim - + @torch.compile(fullgraph=True) def compiled_forward(model, x): return model(x) - - x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=torch.bfloat16) - + + x = torch.randn( + batch_size * num_groups, input_dim, device="cuda", dtype=torch.bfloat16 + ) + try: out = compiled_forward(customop_model, x) return out.shape == (batch_size * num_groups, customop_model.output_dim) @@ -573,16 +623,24 @@ def compiled_forward(model, x): return False -def run_opcheck(customop_model: GroupedMLP_CustomOp): +def run_opcheck(customop_model: GroupedMLP): """Run opcheck on individual custom ops.""" print("\nRunning opcheck on custom ops...") - + # Test silu_mul examples_silu = [ - [torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), - torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)], - [torch.randn(16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True), - torch.randn(16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)], + [ + torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), + ], + [ + torch.randn( + 16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + torch.randn( + 16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + ], ] for i, ex in enumerate(examples_silu): try: @@ -590,13 +648,21 @@ def run_opcheck(customop_model: GroupedMLP_CustomOp): print(f" silu_mul example {i+1}: PASSED") except Exception as e: print(f" silu_mul example {i+1}: FAILED - {e}") - + # Test strided_bmm examples_bmm = [ - [torch.randn(32, 8, 64, device="cuda", dtype=torch.bfloat16), - torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16)], - [torch.randn(16, 12, 256, device="cuda", dtype=torch.bfloat16, 
requires_grad=True), - torch.randn(12, 256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)], + [ + torch.randn(32, 8, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16), + ], + [ + torch.randn( + 16, 12, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + torch.randn( + 12, 256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + ], ] for i, ex in enumerate(examples_bmm): try: @@ -610,6 +676,7 @@ def run_opcheck(customop_model: GroupedMLP_CustomOp): # [6] Benchmark Functions # ============================================================================= + def warmup_gpu(): """Warmup GPU.""" x = torch.randn(1000, 1000, device="cuda") @@ -628,16 +695,16 @@ def benchmark_forward( for i in range(num_warmup): _ = model(x_list[i % len(x_list)]) torch.cuda.synchronize() - + start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - + start.record() for i in range(num_iterations): _ = model(x_list[i % len(x_list)]) end.record() torch.cuda.synchronize() - + return start.elapsed_time(end) / num_iterations @@ -655,16 +722,16 @@ def benchmark_forward_backward( ] x_with_grad = [xi.requires_grad_(True) for xi in x_list] params = list(model.parameters()) - + for i in range(num_warmup): xi = x_with_grad[i % len(x_list)] out = model(xi) torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) torch.cuda.synchronize() - + start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - + start.record() for i in range(num_iterations): xi = x_with_grad[i % len(x_list)] @@ -672,7 +739,7 @@ def benchmark_forward_backward( torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) end.record() torch.cuda.synchronize() - + return start.elapsed_time(end) / num_iterations @@ -683,28 +750,30 @@ def benchmark_silu_mul_bandwidth( ) -> Tuple[float, float, float, float]: """ Benchmark silu_mul kernel bandwidth. 
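+    Bandwidth is estimated as bytes_moved / elapsed_time, counting each tensor
+    once per read or write:
+        forward:  3 * numel * itemsize   (read gate, up; write output)
+        backward: 5 * numel * itemsize   (read grad_output, gate, up;
+                                          write grad_gate, grad_up)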
- + Returns: (fwd_time_ms, fwd_bw_gb_s, bwd_time_ms, bwd_bw_gb_s) """ gate = torch.randn(shape, device="cuda", dtype=dtype) up = torch.randn(shape, device="cuda", dtype=dtype) grad_output = torch.randn(shape, device="cuda", dtype=dtype) - + numel = gate.numel() bytes_per_elem = gate.element_size() fwd_bytes = 3 * numel * bytes_per_elem # read gate, up; write output - bwd_bytes = 5 * numel * bytes_per_elem # read grad_out, gate, up; write grad_gate, grad_up - + bwd_bytes = ( + 5 * numel * bytes_per_elem + ) # read grad_out, gate, up; write grad_gate, grad_up + # Warmup for _ in range(10): _ = silu_mul(gate, up) _ = silu_mul_backward(grad_output, gate, up) torch.cuda.synchronize() - + # Forward benchmark start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - + start.record() for _ in range(num_iterations): _ = silu_mul(gate, up) @@ -712,7 +781,7 @@ def benchmark_silu_mul_bandwidth( torch.cuda.synchronize() fwd_ms = start.elapsed_time(end) / num_iterations fwd_bw = fwd_bytes / (fwd_ms / 1000) / 1e9 - + # Backward benchmark start.record() for _ in range(num_iterations): @@ -721,7 +790,7 @@ def benchmark_silu_mul_bandwidth( torch.cuda.synchronize() bwd_ms = start.elapsed_time(end) / num_iterations bwd_bw = bwd_bytes / (bwd_ms / 1000) / 1e9 - + return fwd_ms, fwd_bw, bwd_ms, bwd_bw @@ -729,14 +798,15 @@ def benchmark_silu_mul_bandwidth( # Main # ============================================================================= + def main(): print("=" * 70) print("Grouped MLP as PyTorch Custom Ops") print("=" * 70) - + torch.cuda.init() print(f"\nDevice: {torch.cuda.get_device_name(0)}") - + # Configuration batch_size = 2560 num_groups = 12 @@ -745,47 +815,53 @@ def main(): output_dim = 1024 dtype = torch.bfloat16 num_iterations = 100 - - print(f""" + + print( + f""" Config: Batch size: {batch_size} Num groups: {num_groups} Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} Dtype: {dtype} -""") - +""" + ) + print("Warming up GPU...") warmup_gpu() - + # Create models print("Creating models...") ref_model = ReferenceGroupedMLP( input_dim, hidden_dim, output_dim, num_groups, dtype=dtype ).cuda() - - customop_model = GroupedMLP_CustomOp( + + customop_model = GroupedMLP( input_dim, hidden_dim, output_dim, num_groups, dtype=dtype ).cuda() - + # Copy weights copy_weights_ref_to_customop(ref_model, customop_model) - + # ========================= # Correctness Verification # ========================= print("\n" + "-" * 70) print("Correctness Verification") print("-" * 70) - - fwd_ok, fwd_diff = check_forward_correctness(ref_model, customop_model, batch_size, dtype) + + fwd_ok, fwd_diff = check_forward_correctness( + ref_model, customop_model, batch_size, dtype + ) print(f"\nForward: {'✓ PASS' if fwd_ok else '✗ FAIL'} (max_diff: {fwd_diff:.2e})") - - bwd_ok, bwd_diff = check_backward_correctness(ref_model, customop_model, batch_size, dtype) + + bwd_ok, bwd_diff = check_backward_correctness( + ref_model, customop_model, batch_size, dtype + ) print(f"Backward: {'✓ PASS' if bwd_ok else '✗ FAIL'} (max_diff: {bwd_diff:.2e})") - + compile_ok = check_torch_compile(customop_model, batch_size) print(f"torch.compile(fullgraph=True): {'✓ PASS' if compile_ok else '✗ FAIL'}") - + # ========================= # opcheck # ========================= @@ -793,70 +869,75 @@ def main(): print("Op Registration Check") print("-" * 70) run_opcheck(customop_model) - + # ========================= # Performance Benchmark # ========================= print("\n" + "-" * 70) print("Performance 
Benchmark") print("-" * 70) - + x_list = [ torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) for _ in range(10) ] - + # Forward print("\n>>> Forward Pass <<<") ref_fwd = benchmark_forward(ref_model, x_list, num_iterations) customop_fwd = benchmark_forward(customop_model, x_list, num_iterations) - + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") - print(f"{'CustomOp (BMM)':<30} {customop_fwd:<12.4f} {ref_fwd/customop_fwd:<10.2f}x") - + print( + f"{'CustomOp (BMM)':<30} {customop_fwd:<12.4f} {ref_fwd/customop_fwd:<10.2f}x" + ) + # Forward + Backward print("\n>>> Forward + Backward <<<") ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations) customop_fwdbwd = benchmark_forward_backward(customop_model, x_list, num_iterations) - + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") print("-" * 52) print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") - print(f"{'CustomOp (BMM)':<30} {customop_fwdbwd:<12.4f} {ref_fwdbwd/customop_fwdbwd:<10.2f}x") - + print( + f"{'CustomOp (BMM)':<30} {customop_fwdbwd:<12.4f} {ref_fwdbwd/customop_fwdbwd:<10.2f}x" + ) + # ========================= # SiLU*Up Kernel Bandwidth # ========================= print("\n" + "-" * 70) print("Fused SiLU*Up Kernel Bandwidth") print("-" * 70) - + shape = (batch_size, num_groups, hidden_dim) fwd_ms, fwd_bw, bwd_ms, bwd_bw = benchmark_silu_mul_bandwidth(shape, dtype) - + print(f"\nShape: {shape}") print(f"\n{'Kernel':<15} {'Time (ms)':<12} {'BW (GB/s)':<12}") print("-" * 40) print(f"{'Forward':<15} {fwd_ms:<12.4f} {fwd_bw:<12.1f}") print(f"{'Backward':<15} {bwd_ms:<12.4f} {bwd_bw:<12.1f}") - + # ========================= # Summary # ========================= print("\n" + "=" * 70) print("Summary") print("=" * 70) - print(f""" + print( + f""" Forward Speedup: {ref_fwd/customop_fwd:.2f}x Fwd+Bwd Speedup: {ref_fwdbwd/customop_fwdbwd:.2f}x SiLU*Up BW: {fwd_bw:.1f} GB/s (fwd), {bwd_bw:.1f} GB/s (bwd) -""") +""" + ) print("=" * 70) print("Done!") if __name__ == "__main__": main() -