diff --git a/examples/commons/ops/GroupedLinear_example.py b/examples/commons/ops/GroupedLinear_example.py new file mode 100644 index 000000000..71fe3633c --- /dev/null +++ b/examples/commons/ops/GroupedLinear_example.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Grouped Linear Layer with Strided BMM Optimization + +================================================================================ +Problem +================================================================================ +Apply num_groups different linear transformations to corresponding slices of input: + + Input: x of shape (B * num_groups, input_dim) + Output: y of shape (B * num_groups, output_dim) + + For each group n: y[b, n, :] = x[b, n, :] @ W[n, :, :] + +================================================================================ +Reference Implementation +================================================================================ +The straightforward approach uses a loop over groups: + + x = x.reshape(B, num_groups, D_in) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(num_groups): + x_i = x_split[i].squeeze(1) # (B, D_in) + out_i = linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, D_out) + +================================================================================ +Optimized Implementation +================================================================================ +Use torch.bmm with strided output to fuse all GEMMs into one kernel: + + x = x.reshape(B, num_groups, D_in) + output = torch.empty(B, num_groups, D_out, ...) # pre-allocate final layout + torch.bmm(x.permute(1,0,2), weight, + out=output.permute(1,0,2)) # cuBLAS writes to strided memory + return output.view(-1, D_out) # O(1) view, no copy. + +Key feature: cuBLAS strided batched GEMM supports strided output via ldc/strideC +parameters, allowing direct write to the transposed memory layout. + +================================================================================ +Performance Results +================================================================================ +Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16 +Device: NVIDIA H100 +BMM_Opt Forward: 1.46x +BMM_Opt Forward+Backward:1.41x + +================================================================================ +""" + +import argparse +from typing import List, Tuple + +import torch +import torch.nn as nn + + +def warmup_gpu(): + """Warmup GPU to get stable timing""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +class ReferenceImpl(nn.Module): + """ + Reference implementation using reshape + split + loop + stack. + Simple but slow due to multiple kernel launches. 
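+
+    Per-row semantics (rows ordered batch-major, i.e. flattened [B, G, D]):
+        y[b * num_groups + n] = x[b * num_groups + n] @ linear_layers[n].weight.T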
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear_layers = nn.ModuleList( + [ + nn.Linear(input_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + + for layer in self.linear_layers: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # reshape: (B*ns, D) -> (B, ns, D) + x = x.reshape(-1, self.num_groups, self.input_dim) + + # split and loop + x_split = torch.split(x, 1, dim=1) + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) # (B, D) + out_i = self.linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + # stack: ns * (B, D_out) -> (B, ns, D_out) -> (B*ns, D_out) + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim, batch_first): + ctx.save_for_backward(x, weight) + ctx.batch_first = batch_first + + if batch_first: + # x: [B, G, D] -> need permute to [G, B, D] for bmm + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + else: + # x: [G, B, D] -> already in correct layout for bmm + output = torch.bmm(x, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + batch_first = ctx.batch_first + grad_x = grad_weight = None + + if batch_first: + # grad_output: [B, G, D_out] + grad_output_t = grad_output.permute(1, 0, 2) # [G, B, D_out] + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) + + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm( + x.permute(1, 0, 2).transpose(-1, -2), grad_output_t + ) + else: + # grad_output: [G, B, D_out] + if ctx.needs_input_grad[0]: + grad_x = torch.bmm(grad_output, weight.transpose(-1, -2)) + + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm(x.transpose(-1, -2), grad_output) + + return grad_x, grad_weight, None, None, None, None + + +class GroupedLinear(nn.Module): + """ + Grouped linear layer: applies num_groups different linear transforms in parallel. + Optimized using batched GEMM with strided output for single kernel launch. + + Args: + input_dim: Input feature dimension. + output_dim: Output feature dimension. + num_groups: Number of independent linear transforms. + batch_first: If True, input layout is [B, G, D] (batch-first, default). + If False, input layout is [G, B, D] (group-first). 
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + batch_first: bool = True, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + self.batch_first = batch_first + + self.weight = nn.Parameter( + torch.empty(num_groups, input_dim, output_dim, device=device, dtype=dtype) + ) + for i in range(num_groups): + nn.init.xavier_normal_(self.weight[i], gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if self.batch_first: + # Input flattened from [B, G, D] -> reshape to [B, G, D] + x = x.reshape(batch_size, self.num_groups, self.input_dim) + else: + # Input flattened from [G, B, D] -> reshape to [G, B, D] + x = x.reshape(self.num_groups, batch_size, self.input_dim) + + output = StridedBmmFunction.apply( + x, + self.weight, + batch_size, + self.num_groups, + self.output_dim, + self.batch_first, + ) + return output.view(-1, self.output_dim) + + +def copy_weights(ref_model: ReferenceImpl, opt_model: GroupedLinear): + """Copy weights from reference to optimized model.""" + with torch.no_grad(): + for i in range(ref_model.num_groups): + opt_model.weight[i].copy_(ref_model.linear_layers[i].weight.T) + + +def check_correctness( + ref_model: ReferenceImpl, + opt_model: GroupedLinear, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float, float]: + """Check forward and backward correctness.""" + # Forward check + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + # Backward check + x_ref = torch.randn( + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + # Input gradient + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + # Weight gradient + ref_weight_grad = torch.stack( + [ref_model.linear_layers[i].weight.grad.T for i in range(num_groups)] + ) + bwd_w_diff = (ref_weight_grad - opt_model.weight.grad).abs().max().item() + + return fwd_diff, bwd_x_diff, bwd_w_diff + + +def benchmark( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + with_backward: bool = False, +) -> float: + """Benchmark forward or forward+backward pass using CUDA events for accurate GPU timing.""" + if with_backward: + x_list = [xi.requires_grad_(True) for xi in x_list] + grad_outputs = [ + torch.randn(xi.shape[0], model.output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + 
end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations # ms + + +def main(): + parser = argparse.ArgumentParser( + description="Grouped GEMM: Reference vs Strided BMM" + ) + parser.add_argument( + "--batch-size", type=int, default=2560, help="Batch size to test" + ) + parser.add_argument("--num-groups", type=int, default=12, help="Number of groups") + parser.add_argument("--input-dim", type=int, default=1024, help="Input dimension") + parser.add_argument("--output-dim", type=int, default=3072, help="Output dimension") + parser.add_argument( + "--iterations", type=int, default=100, help="Number of iterations for timing" + ) + args = parser.parse_args() + + torch.cuda.init() + + # Configuration from args + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + output_dim = args.output_dim + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 60) + print("Grouped GEMM: Reference vs Strided BMM") + print("=" * 60) + print( + f"\nConfig: B={batch_size}, groups={num_groups}, D_in={input_dim}, D_out={output_dim}" + ) + print(f"Device: {torch.cuda.get_device_name(0)}") + + # Warmup GPU + print("\nWarming up GPU...") + warmup_gpu() + + # Create models + ref_model = ReferenceImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() + opt_model = GroupedLinear(input_dim, output_dim, num_groups, dtype=dtype).cuda() + copy_weights(ref_model, opt_model) + + # Correctness check + print("\n" + "-" * 40) + print("Correctness Check") + print("-" * 40) + fwd_diff, bwd_x_diff, bwd_w_diff = check_correctness( + ref_model, opt_model, batch_size, num_groups, input_dim, dtype + ) + print(f"Forward max diff: {fwd_diff:.2e} {'✓' if fwd_diff < 1e-3 else '✗'}") + print(f"Backward dL/dx diff: {bwd_x_diff:.2e} {'✓' if bwd_x_diff < 1e-3 else '✗'}") + print(f"Backward dL/dW diff: {bwd_w_diff:.2e} {'✓' if bwd_w_diff < 1e-3 else '✗'}") + + # Benchmark + print("\n" + "-" * 40) + print("Performance Benchmark") + print("-" * 40) + + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Forward only + ref_fwd = benchmark(ref_model, x_list, num_iterations, with_backward=False) + opt_fwd = benchmark(opt_model, x_list, num_iterations, with_backward=False) + + print(f"\nForward pass (ms):") + print(f" Reference (loop): {ref_fwd:.4f}") + print(f" GroupedLinear: {opt_fwd:.4f}") + print(f" Speedup: {ref_fwd/opt_fwd:.2f}x") + + # Forward + Backward + ref_fwdbwd = benchmark(ref_model, x_list, num_iterations, with_backward=True) + opt_fwdbwd = benchmark(opt_model, x_list, num_iterations, with_backward=True) + + print(f"\nForward + Backward (ms):") + print(f" Reference (loop): {ref_fwdbwd:.4f}") + print(f" GroupedLinear: {opt_fwdbwd:.4f}") + print(f" Speedup: {ref_fwdbwd/opt_fwdbwd:.2f}x") + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py new file mode 100644 index 000000000..0736c2df8 --- /dev/null +++ b/examples/commons/ops/GroupedMLP_example.py @@ -0,0 +1,1162 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Grouped MLP Benchmark: Reference vs Plan A + +================================================================================ +Problem +================================================================================ +Apply num_groups different MLP transformations with GLU gating (SwiGLU/GeGLU): + + For each group n: y[b, n, :] = down(act(gate(x)) * up(x)) + +================================================================================ +Implementations +================================================================================ +Reference: Loop over groups with separate nn.Linear layers +Plan A: 3 independent strided BMMs (gate, up, down) + +================================================================================ +""" +import sys + +sys.path.insert( + 0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu" +) + +import argparse +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.cuda.nvtx as nvtx +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + +from ops.triton_ops.common import triton_autotune + + +def silu_configs(): + configs = [] + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [2, 4, 8, 16]: + config = triton.Config({"x_block_size": x_block_size}, num_warps) + configs.append(config) + return configs + + +# ============================================================================= +# Fused SiLU * Up (SwiGLU pattern): output = silu(gate) * up +# ============================================================================= + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_forward( + output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused forward: output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + # silu(gate) = gate * sigmoid(gate) = gate / (1 + exp(-gate)) + silu_gate = fast_dividef(gate, 1.0 + tl.exp(-gate)) + output = (silu_gate * up).to(output_ptr.dtype.element_ty) + + tl.store(output_ptr + x_offset + cols, output, mask=mask) + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_backward( + grad_gate_ptr: tl.tensor, + grad_up_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """ + Fused backward for output = silu(gate) * up + + grad_gate = grad_output * up * d(silu)/d(gate) + = grad_output * up * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate))) + grad_up = grad_output * silu(gate) + """ + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to( + tl.float32 + ) + gate = 
tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + sigma = tl.sigmoid(gate) + silu_gate = gate * sigma + + # d(silu)/d(gate) = sigma + gate * sigma * (1 - sigma) + dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) + + grad_gate = grad_output * up * dsilu_dgate + grad_up = grad_output * silu_gate + + tl.store( + grad_gate_ptr + x_offset + cols, + grad_gate.to(grad_gate_ptr.dtype.element_ty), + mask=mask, + ) + tl.store( + grad_up_ptr + x_offset + cols, + grad_up.to(grad_up_ptr.dtype.element_ty), + mask=mask, + ) + + +def triton_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Forward: output = silu(gate) * up""" + assert gate.shape == up.shape, f"Shape mismatch: gate {gate.shape} vs up {up.shape}" + x_size = gate.numel() + + with torch.cuda.nvtx.range("gate_contiguous"): + gate_1d = gate.view(-1).contiguous() + + with torch.cuda.nvtx.range("up_contiguous"): + up_1d = up.view(-1).contiguous() + output = torch.empty_like(gate_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_forward[grid]( + output, + gate_1d, + up_1d, + x_size, + ) + return output.view(gate.shape) + + +def triton_silu_mul_bwd( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward: returns (grad_gate, grad_up)""" + shape = gate.shape + x_size = gate.numel() + gate_1d = gate.view(-1).contiguous() + up_1d = up.view(-1).contiguous() + grad_output_1d = grad_output.view(-1).contiguous() + grad_gate = torch.empty_like(gate_1d) + grad_up = torch.empty_like(up_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_backward[grid]( + grad_gate, + grad_up, + grad_output_1d, + gate_1d, + up_1d, + x_size, + ) + return grad_gate.view(shape), grad_up.view(shape) + + +class TritonSiluMul(torch.autograd.Function): + """Autograd function for fused silu(gate) * up""" + + @staticmethod + def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + output = triton_silu_mul_fwd(gate, up) + ctx.save_for_backward(gate, up) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + gate, up = ctx.saved_tensors + grad_gate, grad_up = triton_silu_mul_bwd(grad_output, gate, up) + return grad_gate, grad_up + + +def triton_silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Fused SiLU multiplication (SwiGLU pattern). + + Computes: output = silu(gate) * up + + Args: + gate: Input tensor that goes through SiLU activation + up: Input tensor that multiplies with activated gate + + Returns: + output: silu(gate) * up + """ + gate = gate.contiguous() + up = up.contiguous() + return TritonSiluMul.apply(gate, up) + + +def warmup_gpu(): + """Warmup GPU to get stable timing.""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +def benchmark_fused_silu_mul_bandwidth( + batch_size: int, + num_groups: int, + hidden_dim: int, + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, + num_warmup: int = 10, +) -> Tuple[float, float, float]: + """ + Benchmark fused SiLU * up kernel and calculate memory bandwidth. 
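+
+    Traffic model: the kernel reads gate and up and writes output, so each call
+    moves 3 * numel * element_size bytes; bandwidth is that volume divided by
+    the measured per-iteration time.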
+ + Returns: + (time_ms, bandwidth_gb_s, achieved_percent) + """ + # Create tensors matching the shape used in GroupedMLP + shape = (batch_size, num_groups, hidden_dim) + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + + # Calculate data volume + numel = gate.numel() + bytes_per_element = gate.element_size() # 2 for bf16, 4 for fp32 + # Read: gate + up, Write: output + total_bytes = 3 * numel * bytes_per_element + + # Warmup + for _ in range(num_warmup): + _ = triton_silu_mul(gate, up) + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iterations): + _ = triton_silu_mul(gate, up) + end_event.record() + torch.cuda.synchronize() + + time_ms = start_event.elapsed_time(end_event) / num_iterations + time_s = time_ms / 1000.0 + + # Calculate bandwidth + bandwidth_gb_s = total_bytes / time_s / 1e9 + + # Get theoretical peak bandwidth (for reference) + # H100 SXM: ~3.35 TB/s, A100 SXM: ~2.0 TB/s, A100 PCIe: ~1.9 TB/s + # You can query this from torch.cuda.get_device_properties + torch.cuda.get_device_properties(0) + # Memory bandwidth = memory_clock_rate (kHz) * bus_width (bits) * 2 (DDR) / 8 (bits to bytes) + # Note: This is approximate; actual peak may differ + theoretical_peak_gb_s = 2000.0 # Default estimate, adjust based on your GPU + + achieved_percent = bandwidth_gb_s / theoretical_peak_gb_s * 100 + + return time_ms, bandwidth_gb_s, achieved_percent + + +def benchmark_fused_silu_mul_backward_bandwidth( + batch_size: int, + num_groups: int, + hidden_dim: int, + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, + num_warmup: int = 10, +) -> Tuple[float, float]: + """ + Benchmark fused SiLU * up backward kernel and calculate memory bandwidth. 
+ + Backward reads: grad_output, gate, up (3 tensors) + Backward writes: grad_gate, grad_up (2 tensors) + Total: 5 * numel * bytes_per_element + + Returns: + (time_ms, bandwidth_gb_s) + """ + shape = (batch_size, num_groups, hidden_dim) + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + grad_output = torch.randn(shape, device="cuda", dtype=dtype) + + numel = gate.numel() + bytes_per_element = gate.element_size() + # Read: grad_output + gate + up, Write: grad_gate + grad_up + total_bytes = 5 * numel * bytes_per_element + + # Warmup + for _ in range(num_warmup): + _ = triton_silu_mul_bwd(grad_output, gate, up) + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iterations): + _ = triton_silu_mul_bwd(grad_output, gate, up) + end_event.record() + torch.cuda.synchronize() + + time_ms = start_event.elapsed_time(end_event) / num_iterations + time_s = time_ms / 1000.0 + bandwidth_gb_s = total_bytes / time_s / 1e9 + + return time_ms, bandwidth_gb_s + + +def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: + """Get activation function by name.""" + if activation is None: + return None + activation_map = { + "silu": F.silu, + "swish": F.silu, + "gelu": F.gelu, + "relu": F.relu, + "tanh": torch.tanh, + "sigmoid": torch.sigmoid, + "swiglu": triton_silu_mul, + } + if activation.lower() not in activation_map: + raise ValueError(f"Unknown activation: {activation}") + return activation_map[activation.lower()] + + +# ============================================================================= +# Reference Implementation +# ============================================================================= + + +class ReferenceGroupedMLP(nn.Module): + """Reference implementation using loop over groups.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "swiglu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + self.up_proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + else: + self.proj = nn.ModuleList( + [ + nn.Linear( + input_dim, hidden_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + + self.down_proj = nn.ModuleList( + [ + nn.Linear( + hidden_dim, output_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + + self._init_weights() + + def _init_weights(self): + for module_list in [ + getattr(self, "gate_proj", []), + getattr(self, "up_proj", []), + getattr(self, "proj", []), + self.down_proj, + ]: + for layer in module_list: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + if enable_nvtx: + with nvtx.range("Ref_reshape"): + x = x.reshape(-1, self.num_groups, self.input_dim) + + with nvtx.range("Ref_split"): + 
x_split = torch.split(x, 1, dim=1) + + with nvtx.range("Ref_loop_gemm"): + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + with nvtx.range("Ref_stack"): + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + else: + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + return output + + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + hidden_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + hidden_list.append(hidden_i) + + return torch.stack(hidden_list, dim=1).reshape(-1, self.hidden_dim) + + +# ============================================================================= +# Strided BMM Function +# ============================================================================= + + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim): + ctx.save_for_backward(x, weight) + + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + grad_x = grad_weight = None + + grad_output_t = grad_output.permute(1, 0, 2) + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) + + if ctx.needs_input_grad[1]: + x_t = x.permute(1, 0, 2) + grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight, None, None, None + + +# ============================================================================= +# Plan A: 3 Independent BMMs +# ============================================================================= + + +class GroupedMLP_PlanA(nn.Module): + """Plan A: 3 independent strided BMMs (gate, up, down).""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", 
+ dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_weight = nn.Parameter( + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) + ) + self.up_weight = nn.Parameter( + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) + ) + else: + self.proj_weight = nn.Parameter( + torch.empty( + num_groups, input_dim, hidden_dim, device=device, dtype=dtype + ) + ) + + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + if self.use_gating: + nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) + nn.init.xavier_normal_(self.up_weight[i], gain=1.0) + else: + nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if enable_nvtx: + with nvtx.range("PlanA_reshape"): + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + with nvtx.range("PlanA_gate_bmm"): + gate = StridedBmmFunction.apply( + x, + self.gate_weight, + batch_size, + self.num_groups, + self.hidden_dim, + ) + + with nvtx.range("PlanA_up_bmm"): + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + # hidden = self.act_fn(gate) * up + hidden = triton_silu_mul(gate, up) + else: + hidden = gate * up + else: + with nvtx.range("PlanA_proj_bmm"): + hidden = StridedBmmFunction.apply( + x, + self.proj_weight, + batch_size, + self.num_groups, + self.hidden_dim, + ) + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + with nvtx.range("PlanA_down_bmm"): + output = StridedBmmFunction.apply( + hidden, + self.down_weight, + batch_size, + self.num_groups, + self.output_dim, + ) + + with nvtx.range("PlanA_view"): + return output.view(-1, self.output_dim) + else: + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + # hidden = self.act_fn(gate) * up + hidden = triton_silu_mul(gate, up) + else: + hidden = gate * up + else: + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + return output.view(-1, self.output_dim) + + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" + batch_size = x.shape[0] // self.num_groups + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + up = StridedBmmFunction.apply( + x, self.up_weight, 
batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = triton_silu_mul(gate, up) + else: + hidden = gate * up + else: + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + return hidden.view(-1, self.hidden_dim) + + +# ============================================================================= +# Weight Copy Utilities +# ============================================================================= + + +def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanA): + with torch.no_grad(): + num_groups = ref_model.num_groups + if ref_model.use_gating: + for i in range(num_groups): + opt_model.gate_weight[i].copy_(ref_model.gate_proj[i].weight.T) + opt_model.up_weight[i].copy_(ref_model.up_proj[i].weight.T) + else: + for i in range(num_groups): + opt_model.proj_weight[i].copy_(ref_model.proj[i].weight.T) + for i in range(num_groups): + opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +# ============================================================================= +# Correctness Check +# ============================================================================= + + +def check_correctness( + ref_model: nn.Module, + opt_model: nn.Module, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float]: + """Check forward and backward correctness.""" + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + x_ref = torch.randn( + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + return fwd_diff, bwd_x_diff + + +# ============================================================================= +# Benchmark Functions (same as benchmark_batched_gemm.py) +# ============================================================================= + + +def benchmark_forward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward pass using CUDA events for accurate GPU timing.""" + model_name = model.__class__.__name__ + + # Warmup + for i in range(num_warmup): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwd_iter{i}"): + _ = model(x_list[i % len(x_list)], enable_nvtx=True) + else: + for i in range(num_iterations): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +def benchmark_forward_to_hidden( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward pass up to hidden (excluding down 
projection).""" + # Warmup + for i in range(num_warmup): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +def benchmark_forward_backward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward + backward pass using CUDA events.""" + model_name = model.__class__.__name__ + output_dim = model.output_dim + + grad_outputs = [ + torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + + x_with_grad = [xi.requires_grad_(True) for xi in x_list] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwdbwd_iter{i}"): + xi = x_with_grad[i % len(x_list)] + grad_out = grad_outputs[i % len(x_list)] + with nvtx.range("forward"): + out = model(xi, enable_nvtx=True) + # Separate backward into two parts for clearer profiling + with nvtx.range("backward_activation"): + # dL/dx (activation gradient) + grad_x = torch.autograd.grad( + outputs=out, + inputs=xi, + grad_outputs=grad_out, + retain_graph=True, # Keep graph for weight gradient + ) + with nvtx.range("backward_weight"): + # dL/dW (weight gradient) + grad_w = torch.autograd.grad( + outputs=out, + inputs=params, + grad_outputs=grad_out, + retain_graph=False, + ) + else: + for i in range(num_iterations): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="Grouped MLP Benchmark: Reference vs Plan A" + ) + parser.add_argument("--batch-size", type=int, default=2560) + parser.add_argument("--num-groups", type=int, default=12) + parser.add_argument("--input-dim", type=int, default=1024) + parser.add_argument("--hidden-dim", type=int, default=3072) + parser.add_argument("--output-dim", type=int, default=1024) + parser.add_argument( + "--activation", + type=str, + default="silu", + choices=["silu", "gelu", "relu", "tanh", "none"], + ) + parser.add_argument("--no-gating", action="store_true") + parser.add_argument("--iterations", type=int, default=100) + parser.add_argument( + "--enable-nvtx", + action="store_true", + help="Enable NVTX markers (use with nsys profile)", + ) + parser.add_argument( + "--compile", action="store_true", help="Use 
torch.compile() to optimize models" + ) + args = parser.parse_args() + + torch.cuda.init() + + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + hidden_dim = args.hidden_dim + output_dim = args.output_dim + activation = None if args.activation == "none" else args.activation + use_gating = not args.no_gating + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 80) + print("Grouped MLP Benchmark: Reference vs Plan A") + print("=" * 80) + + if args.enable_nvtx: + print("\n*** NVTX PROFILING MODE ***") + print("Run with: nsys profile -o --trace=cuda,nvtx python ...") + + print( + f""" +Config: + Batch size: {batch_size} + Num groups: {num_groups} + Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} + Mode: {"GLU (SwiGLU)" if use_gating else "Simple MLP"} + Activation: {activation if activation else "None"} + Dtype: {dtype} + Device: {torch.cuda.get_device_name(0)} + Iterations: {num_iterations} +""" + ) + + print("Warming up GPU...") + warmup_gpu() + + # Create models + print("Creating models...") + ref_model = ReferenceGroupedMLP( + input_dim, + hidden_dim, + output_dim, + num_groups, + use_gating=use_gating, + activation=activation, + dtype=dtype, + ).cuda() + + plan_a_model = GroupedMLP_PlanA( + input_dim, + hidden_dim, + output_dim, + num_groups, + use_gating=use_gating, + activation=activation, + dtype=dtype, + ).cuda() + + copy_weights_to_plan_a(ref_model, plan_a_model) + + # Apply torch.compile() if requested + if args.compile: + print("\nApplying torch.compile() to all models...") + ref_model = torch.compile(ref_model) + plan_a_model = torch.compile(plan_a_model) + print("Compilation complete (will JIT compile on first run).") + + # Correctness check + print("-" * 60) + print("Correctness Check") + print("-" * 60) + + fwd_a, bwd_a = check_correctness( + ref_model, plan_a_model, batch_size, num_groups, input_dim, dtype + ) + print(f"Plan A - Forward diff: {fwd_a:.2e}, Backward diff: {bwd_a:.2e}") + + # Prepare test data + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Benchmark (NEVER use NVTX for timing - NVTX adds Python overhead) + print("\n" + "-" * 60) + print("Performance Benchmark (NVTX disabled for accurate timing)") + print("-" * 60) + + # Forward - always benchmark without NVTX + print("\n>>> Forward Pass <<<") + ref_fwd = benchmark_forward(ref_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwd = benchmark_forward( + plan_a_model, x_list, num_iterations, enable_nvtx=False + ) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x" + ) + + # Forward to Hidden (excluding down projection) + print("\n>>> Forward to Hidden (excluding down_proj) <<<") + ref_hidden = benchmark_forward_to_hidden(ref_model, x_list, num_iterations) + plan_a_hidden = benchmark_forward_to_hidden(plan_a_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_hidden:<12.4f} {'1.00x':<10}") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_hidden:<12.4f} {ref_hidden/plan_a_hidden:<10.2f}x" + ) + + # Forward + Backward - always benchmark without NVTX + print("\n>>> Forward + Backward <<<") + ref_fwdbwd = benchmark_forward_backward( + ref_model, x_list, num_iterations, 
enable_nvtx=False + ) + plan_a_fwdbwd = benchmark_forward_backward( + plan_a_model, x_list, num_iterations, enable_nvtx=False + ) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") + print( + f"{'Plan A (batched BMM)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x" + ) + + # NVTX profiling run (separate from benchmark) + if args.enable_nvtx: + print("\n" + "-" * 60) + print("NVTX Profiling Run (for nsys analysis only)") + print("-" * 60) + torch.cuda.profiler.start() + + # Run a few iterations with NVTX for profiling + nvtx_iterations = min(10, num_iterations) + _ = benchmark_forward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward( + ref_model, x_list, nvtx_iterations, enable_nvtx=True + ) + _ = benchmark_forward_backward( + plan_a_model, x_list, nvtx_iterations, enable_nvtx=True + ) + + torch.cuda.profiler.stop() + print("NVTX profiling complete.") + + # Fused kernel bandwidth benchmark + print("\n" + "-" * 60) + print("Fused SiLU*Up Kernel Bandwidth Analysis") + print("-" * 60) + + fwd_time, fwd_bw, fwd_pct = benchmark_fused_silu_mul_bandwidth( + batch_size, num_groups, hidden_dim, dtype, num_iterations + ) + bwd_time, bwd_bw = benchmark_fused_silu_mul_backward_bandwidth( + batch_size, num_groups, hidden_dim, dtype, num_iterations + ) + + # Calculate data sizes for reference + numel = batch_size * num_groups * hidden_dim + bytes_per_elem = 2 if dtype == torch.bfloat16 else 4 + fwd_data_mb = 3 * numel * bytes_per_elem / 1e6 + bwd_data_mb = 5 * numel * bytes_per_elem / 1e6 + + print(f"\nTensor shape: ({batch_size}, {num_groups}, {hidden_dim})") + print(f"Elements: {numel:,} ({numel * bytes_per_elem / 1e6:.2f} MB per tensor)") + + print(f"\n{'Kernel':<20} {'Time (ms)':<12} {'Data (MB)':<12} {'BW (GB/s)':<12}") + print("-" * 56) + print( + f"{'Forward (silu*up)':<20} {fwd_time:<12.4f} {fwd_data_mb:<12.2f} {fwd_bw:<12.1f}" + ) + print(f"{'Backward':<20} {bwd_time:<12.4f} {bwd_data_mb:<12.2f} {bwd_bw:<12.1f}") + + print(f"\nNote: Peak memory bandwidth varies by GPU:") + print(f" - H100 SXM: ~3350 GB/s") + print(f" - A100 SXM: ~2039 GB/s") + print(f" - A100 PCIe: ~1935 GB/s") + + # Summary + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + print( + f""" +Implementation Details: + Reference: Loop over {num_groups} groups, uses nn.Linear (C++ autograd) + Plan A: Batched BMM with custom StridedBmmFunction + +Forward Speedup (full MLP): + Plan A vs Reference: {ref_fwd/plan_a_fwd:.2f}x + +Forward to Hidden (excluding down_proj): + Plan A vs Reference: {ref_hidden/plan_a_hidden:.2f}x + +Fwd+Bwd Speedup: + Plan A vs Reference: {ref_fwdbwd/plan_a_fwdbwd:.2f}x + +Fused SiLU*Up Kernel: + Forward: {fwd_bw:.1f} GB/s + Backward: {bwd_bw:.1f} GB/s +""" + ) + print("=" * 80) + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/commons/ops/grouped_mlp_customop.py b/examples/commons/ops/grouped_mlp_customop.py new file mode 100644 index 000000000..1f0fdd3e1 --- /dev/null +++ b/examples/commons/ops/grouped_mlp_customop.py @@ -0,0 +1,943 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Grouped MLP as PyTorch Custom Ops + +Structure +[1] Triton Kernels - SiLU*Up kernel +[2] Custom Ops - strided_bmm, silu_mul, grouped_mlp_gated_fwd/bwd +[3] nn.Module Wrappers - GroupedMLP_CustomOp, ReferenceGroupedMLP +[4] Correctness & Benchmark +""" + +import sys + +sys.path.insert( + 0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu" +) + +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +try: + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + from triton.language.math import fast_dividef + +from ops.triton_ops.common import triton_autotune + +# ============================================================================= +# [1] Triton Kernel Configurations +# ============================================================================= + + +def silu_configs(): + configs = [] + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [2, 4, 8, 16]: + config = triton.Config({"x_block_size": x_block_size}, num_warps) + configs.append(config) + return configs + + +# ============================================================================= +# [1] Triton Kernels - SiLU * Up (SwiGLU pattern) +# ============================================================================= + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_forward_kernel( + output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused forward: output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + silu_gate = fast_dividef(gate, 1.0 + tl.exp(-gate)) + output = (silu_gate * up).to(output_ptr.dtype.element_ty) + + tl.store(output_ptr + x_offset + cols, output, mask=mask) + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_backward_kernel( + grad_gate_ptr: tl.tensor, + grad_up_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused backward for output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to( + tl.float32 + ) + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + sigma = tl.sigmoid(gate) + silu_gate = gate * sigma + dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) + + grad_gate = grad_output * up * dsilu_dgate + grad_up = grad_output * silu_gate + + tl.store( + grad_gate_ptr + x_offset + cols, + grad_gate.to(grad_gate_ptr.dtype.element_ty), + mask=mask, + ) + tl.store( + grad_up_ptr + x_offset + cols, + grad_up.to(grad_up_ptr.dtype.element_ty), + mask=mask, + ) + + +# ============================================================================= +# [1] Triton Kernel Launchers (Internal) +# 
============================================================================= + + +def _launch_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Internal: launch forward kernel""" + x_size = gate.numel() + gate_1d = gate.reshape(-1).contiguous() + up_1d = up.reshape(-1).contiguous() + output = torch.empty_like(gate_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_forward_kernel[grid](output, gate_1d, up_1d, x_size) + return output.view(gate.shape) + + +def _launch_silu_mul_bwd( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Internal: launch backward kernel""" + shape = gate.shape + x_size = gate.numel() + gate_1d = gate.reshape(-1).contiguous() + up_1d = up.reshape(-1).contiguous() + grad_output_1d = grad_output.reshape(-1).contiguous() + grad_gate = torch.empty_like(gate_1d) + grad_up = torch.empty_like(up_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_backward_kernel[grid]( + grad_gate, grad_up, grad_output_1d, gate_1d, up_1d, x_size + ) + return grad_gate.view(shape), grad_up.view(shape) + + +# ============================================================================= +# [2] Custom Op: silu_mul +# ============================================================================= + + +@torch.library.custom_op("grouped_mlp::silu_mul", mutates_args=(), device_types="cuda") +def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Fused SiLU multiplication: output = silu(gate) * up + + This is the SwiGLU activation pattern used in modern LLMs. + """ + torch._check( + gate.shape == up.shape, lambda: f"Shape mismatch: {gate.shape} vs {up.shape}" + ) + return _launch_silu_mul_fwd(gate.contiguous(), up.contiguous()) + + +@silu_mul.register_fake +def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + torch._check(gate.shape == up.shape) + return torch.empty_like(gate) + + +@torch.library.custom_op( + "grouped_mlp::silu_mul_backward", mutates_args=(), device_types="cuda" +) +def silu_mul_backward( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward for silu_mul: returns (grad_gate, grad_up)""" + return _launch_silu_mul_bwd( + grad_output.contiguous(), gate.contiguous(), up.contiguous() + ) + + +@silu_mul_backward.register_fake +def _( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(gate), torch.empty_like(up) + + +def _silu_mul_bwd_fn(ctx, grad_output): + gate, up = ctx.saved_tensors + return silu_mul_backward(grad_output, gate, up) + + +def _silu_mul_setup_ctx(ctx, inputs, output): + gate, up = inputs + ctx.save_for_backward(gate, up) + + +silu_mul.register_autograd(_silu_mul_bwd_fn, setup_context=_silu_mul_setup_ctx) + + +# ============================================================================= +# [2] Custom Op: strided_bmm (Strided Batched Matrix Multiply) +# ============================================================================= + + +@torch.library.custom_op( + "grouped_mlp::strided_bmm", mutates_args=(), device_types="cuda" +) +def strided_bmm( + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) +) -> torch.Tensor: + """ + Strided BMM: output[b, n, :] = x[b, n, :] @ weight[n, :, :] + + Input: x - (batch_size, num_groups, input_dim) + Weight: weight - (num_groups, input_dim, output_dim) + Output: - (batch_size, 
num_groups, output_dim) + + Internally uses permute + torch.bmm for efficiency. + """ + batch_size, num_groups, _ = x.shape + output_dim = weight.shape[2] + + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + # x: (B, N, K) -> (N, B, K) + # weight: (N, K, M) + # out: (N, B, M) -> (B, N, M) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + return output + + +@strided_bmm.register_fake +def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + batch_size, num_groups, _ = x.shape + output_dim = weight.shape[2] + return x.new_empty(batch_size, num_groups, output_dim) + + +@torch.library.custom_op( + "grouped_mlp::strided_bmm_backward", mutates_args=(), device_types="cuda" +) +def strided_bmm_backward( + grad_output: torch.Tensor, # (B, N, M) + x: torch.Tensor, # (B, N, K) + weight: torch.Tensor, # (N, K, M) +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Backward for strided_bmm. + + grad_x = grad_output @ weight.T => (B, N, K) + grad_weight = x.T @ grad_output => (N, K, M) + """ + grad_output_t = grad_output.permute(1, 0, 2) # (N, B, M) + + # grad_x: (N, B, M) @ (N, M, K) -> (N, B, K) -> (B, N, K) + grad_x = torch.empty_like(x) + torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) + + # grad_weight: (N, K, B) @ (N, B, M) -> (N, K, M) + x_t = x.permute(1, 0, 2) # (N, B, K) + grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight + + +@strided_bmm_backward.register_fake +def _( + grad_output: torch.Tensor, x: torch.Tensor, weight: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(weight) + + +def _strided_bmm_bwd_fn(ctx, grad_output): + x, weight = ctx.saved_tensors + return strided_bmm_backward(grad_output, x, weight) + + +def _strided_bmm_setup_ctx(ctx, inputs, output): + x, weight = inputs + ctx.save_for_backward(x, weight) + + +strided_bmm.register_autograd(_strided_bmm_bwd_fn, setup_context=_strided_bmm_setup_ctx) + + +def grouped_mlp_gated_forward( + x: torch.Tensor, # (B*N, D_in) + gate_weight: torch.Tensor, # (N, D_in, D_hidden) + up_weight: torch.Tensor, # (N, D_in, D_hidden) + down_weight: torch.Tensor, # (N, D_hidden, D_out) + num_groups: int, +) -> torch.Tensor: + """ + Grouped MLP forward with gating (SwiGLU pattern). + + This is a composition of custom ops (strided_bmm, silu_mul). + Autograd is handled automatically through the registered backward of each op. 
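+    Each op also registers a fake (meta) implementation, which is what allows
+    this composition to be traced by torch.compile.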
+ + For each group n: + gate = x @ gate_weight[n] + up = x @ up_weight[n] + hidden = silu(gate) * up + output = hidden @ down_weight[n] + + Args: + x: Input tensor, shape (B*N, D_in) + gate_weight: Gate projection weights, shape (N, D_in, D_hidden) + up_weight: Up projection weights, shape (N, D_in, D_hidden) + down_weight: Down projection weights, shape (N, D_hidden, D_out) + num_groups: Number of groups (N) + + Returns: + Output tensor, shape (B*N, D_out) + """ + batch_size = x.shape[0] // num_groups + input_dim = gate_weight.shape[1] + output_dim = down_weight.shape[2] + + # Reshape: (B*N, D_in) -> (B, N, D_in) + x_3d = x.reshape(batch_size, num_groups, input_dim) + + # Gate BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) + gate = strided_bmm(x_3d, gate_weight) + + # Up BMM: (B, N, D_in) @ (N, D_in, D_hidden) -> (B, N, D_hidden) + up = strided_bmm(x_3d, up_weight) + + # Fused SiLU activation: hidden = silu(gate) * up + hidden = silu_mul(gate, up) + + # Down BMM: (B, N, D_hidden) @ (N, D_hidden, D_out) -> (B, N, D_out) + output = strided_bmm(hidden, down_weight) + + # Reshape: (B, N, D_out) -> (B*N, D_out) + return output.reshape(-1, output_dim) + + +# ============================================================================= +# [3] nn.Module: GroupedMLP_CustomOp +# ============================================================================= + + +class GroupedMLP(nn.Module): + """ + Grouped MLP using custom ops. + + This implementation uses registered custom ops instead of autograd.Function, + making it compatible with torch.compile and scenarios where autograd is not available. + + Forward computation: + For each group n: + gate = x @ gate_weight[n] + up = x @ up_weight[n] + hidden = silu(gate) * up + output = hidden @ down_weight[n] + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + # Weights: (N, D_in, D_hidden) or (N, D_hidden, D_out) + self.gate_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.up_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) + + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) + nn.init.xavier_normal_(self.up_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using custom ops. + + Args: + x: Input tensor, shape (B*N, D_in) + + Returns: + Output tensor, shape (B*N, D_out) + """ + return grouped_mlp_gated_forward( + x, + self.gate_weight, + self.up_weight, + self.down_weight, + self.num_groups, + ) + + def forward_decomposed(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with explicit steps (for debugging/profiling). + + Same computation as forward(), but with explicit intermediate tensors. 
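+
+        Intermediate shapes (B = x.shape[0] // num_groups, N = num_groups):
+            x_3d     : (B, N, input_dim)
+            gate, up : (B, N, hidden_dim)
+            hidden   : (B, N, hidden_dim)
+            output   : (B, N, output_dim), returned flattened as (B*N, output_dim)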
+ """ + batch_size = x.shape[0] // self.num_groups + + # Reshape + x_3d = x.reshape(batch_size, self.num_groups, self.input_dim) + + # 3 BMMs + 1 fused activation + gate = strided_bmm(x_3d, self.gate_weight) + up = strided_bmm(x_3d, self.up_weight) + hidden = silu_mul(gate, up) + output = strided_bmm(hidden, self.down_weight) + + return output.reshape(-1, self.output_dim) + + +# ============================================================================= +# [3] nn.Module: ReferenceGroupedMLP +# ============================================================================= + + +class ReferenceGroupedMLP(nn.Module): + """ + Reference implementation using loop over groups with nn.Linear. + + This is the baseline for correctness verification. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.gate_proj = nn.ModuleList( + [ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + self.up_proj = nn.ModuleList( + [ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + self.down_proj = nn.ModuleList( + [ + nn.Linear( + hidden_dim, output_dim, bias=False, device=device, dtype=dtype + ) + for _ in range(num_groups) + ] + ) + + self._init_weights() + + def _init_weights(self): + for module_list in [self.gate_proj, self.up_proj, self.down_proj]: + for layer in module_list: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward with loop over groups.""" + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + hidden_i = F.silu(gate_i) * up_i + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + +# ============================================================================= +# [4] Weight Copy Utilities +# ============================================================================= + + +def copy_weights_ref_to_customop( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP, +): + """Copy weights from Reference model to CustomOp model.""" + with torch.no_grad(): + for i in range(ref_model.num_groups): + # nn.Linear weight is (out, in), we need (in, out) + customop_model.gate_weight[i].copy_(ref_model.gate_proj[i].weight.T) + customop_model.up_weight[i].copy_(ref_model.up_proj[i].weight.T) + customop_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +# ============================================================================= +# [5] Correctness Verification +# ============================================================================= + + +def check_forward_correctness( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP, + batch_size: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[bool, float]: + """Check forward correctness.""" + num_groups = ref_model.num_groups + input_dim = ref_model.input_dim + + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + + with torch.no_grad(): + ref_out = ref_model(x) 
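+        # Run the custom-op model on the same input; with weights copied from
+        # the reference, the two outputs should differ only by bf16 rounding.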
+ customop_out = customop_model(x) + + max_diff = (ref_out - customop_out).abs().max().item() + # bf16 has limited precision, use looser tolerance + atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + passed = torch.allclose(ref_out, customop_out, atol=atol, rtol=rtol) + + return passed, max_diff + + +def check_backward_correctness( + ref_model: ReferenceGroupedMLP, + customop_model: GroupedMLP, + batch_size: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[bool, float]: + """Check backward correctness (input gradient).""" + num_groups = ref_model.num_groups + input_dim = ref_model.input_dim + + x_ref = torch.randn( + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + x_customop = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + customop_out = customop_model(x_customop) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + customop_out.backward(grad_output) + + max_diff = (x_ref.grad - x_customop.grad).abs().max().item() + # bf16 has ~3 significant digits precision, multiple BMM ops accumulate error + # Use looser tolerance: 5e-2 for bf16 (vs 1e-2 for fp32) + atol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + rtol = 5e-2 if dtype == torch.bfloat16 else 1e-3 + passed = torch.allclose(x_ref.grad, x_customop.grad, atol=atol, rtol=rtol) + + return passed, max_diff + + +def check_torch_compile(customop_model: GroupedMLP, batch_size: int) -> bool: + """Check torch.compile compatibility.""" + num_groups = customop_model.num_groups + input_dim = customop_model.input_dim + + @torch.compile(fullgraph=True) + def compiled_forward(model, x): + return model(x) + + x = torch.randn( + batch_size * num_groups, input_dim, device="cuda", dtype=torch.bfloat16 + ) + + try: + out = compiled_forward(customop_model, x) + return out.shape == (batch_size * num_groups, customop_model.output_dim) + except Exception as e: + print(f"torch.compile failed: {e}") + return False + + +def run_opcheck(customop_model: GroupedMLP): + """Run opcheck on individual custom ops.""" + print("\nRunning opcheck on custom ops...") + + # Test silu_mul + examples_silu = [ + [ + torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), + ], + [ + torch.randn( + 16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + torch.randn( + 16, 12, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + ], + ] + for i, ex in enumerate(examples_silu): + try: + torch.library.opcheck(silu_mul, ex) + print(f" silu_mul example {i+1}: PASSED") + except Exception as e: + print(f" silu_mul example {i+1}: FAILED - {e}") + + # Test strided_bmm + examples_bmm = [ + [ + torch.randn(32, 8, 64, device="cuda", dtype=torch.bfloat16), + torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16), + ], + [ + torch.randn( + 16, 12, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + torch.randn( + 12, 256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True + ), + ], + ] + for i, ex in enumerate(examples_bmm): + try: + torch.library.opcheck(strided_bmm, ex) + print(f" strided_bmm example {i+1}: PASSED") + except Exception as e: + print(f" strided_bmm example {i+1}: FAILED - {e}") + + +# ============================================================================= +# [6] Benchmark Functions +# ============================================================================= + + +def warmup_gpu(): + 
"""Warmup GPU.""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +def benchmark_forward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward pass.""" + for i in range(num_warmup): + _ = model(x_list[i % len(x_list)]) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for i in range(num_iterations): + _ = model(x_list[i % len(x_list)]) + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / num_iterations + + +def benchmark_forward_backward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward + backward pass.""" + output_dim = model.output_dim + grad_outputs = [ + torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + x_with_grad = [xi.requires_grad_(True) for xi in x_list] + params = list(model.parameters()) + + for i in range(num_warmup): + xi = x_with_grad[i % len(x_list)] + out = model(xi) + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for i in range(num_iterations): + xi = x_with_grad[i % len(x_list)] + out = model(xi) + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / num_iterations + + +def benchmark_silu_mul_bandwidth( + shape: Tuple[int, ...], + dtype: torch.dtype = torch.bfloat16, + num_iterations: int = 100, +) -> Tuple[float, float, float, float]: + """ + Benchmark silu_mul kernel bandwidth. 
+ + Returns: (fwd_time_ms, fwd_bw_gb_s, bwd_time_ms, bwd_bw_gb_s) + """ + gate = torch.randn(shape, device="cuda", dtype=dtype) + up = torch.randn(shape, device="cuda", dtype=dtype) + grad_output = torch.randn(shape, device="cuda", dtype=dtype) + + numel = gate.numel() + bytes_per_elem = gate.element_size() + fwd_bytes = 3 * numel * bytes_per_elem # read gate, up; write output + bwd_bytes = ( + 5 * numel * bytes_per_elem + ) # read grad_out, gate, up; write grad_gate, grad_up + + # Warmup + for _ in range(10): + _ = silu_mul(gate, up) + _ = silu_mul_backward(grad_output, gate, up) + torch.cuda.synchronize() + + # Forward benchmark + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(num_iterations): + _ = silu_mul(gate, up) + end.record() + torch.cuda.synchronize() + fwd_ms = start.elapsed_time(end) / num_iterations + fwd_bw = fwd_bytes / (fwd_ms / 1000) / 1e9 + + # Backward benchmark + start.record() + for _ in range(num_iterations): + _ = silu_mul_backward(grad_output, gate, up) + end.record() + torch.cuda.synchronize() + bwd_ms = start.elapsed_time(end) / num_iterations + bwd_bw = bwd_bytes / (bwd_ms / 1000) / 1e9 + + return fwd_ms, fwd_bw, bwd_ms, bwd_bw + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + print("=" * 70) + print("Grouped MLP as PyTorch Custom Ops") + print("=" * 70) + + torch.cuda.init() + print(f"\nDevice: {torch.cuda.get_device_name(0)}") + + # Configuration + batch_size = 2560 + num_groups = 12 + input_dim = 1024 + hidden_dim = 3072 + output_dim = 1024 + dtype = torch.bfloat16 + num_iterations = 100 + + print( + f""" +Config: + Batch size: {batch_size} + Num groups: {num_groups} + Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} + Dtype: {dtype} +""" + ) + + print("Warming up GPU...") + warmup_gpu() + + # Create models + print("Creating models...") + ref_model = ReferenceGroupedMLP( + input_dim, hidden_dim, output_dim, num_groups, dtype=dtype + ).cuda() + + customop_model = GroupedMLP( + input_dim, hidden_dim, output_dim, num_groups, dtype=dtype + ).cuda() + + # Copy weights + copy_weights_ref_to_customop(ref_model, customop_model) + + # ========================= + # Correctness Verification + # ========================= + print("\n" + "-" * 70) + print("Correctness Verification") + print("-" * 70) + + fwd_ok, fwd_diff = check_forward_correctness( + ref_model, customop_model, batch_size, dtype + ) + print(f"\nForward: {'✓ PASS' if fwd_ok else '✗ FAIL'} (max_diff: {fwd_diff:.2e})") + + bwd_ok, bwd_diff = check_backward_correctness( + ref_model, customop_model, batch_size, dtype + ) + print(f"Backward: {'✓ PASS' if bwd_ok else '✗ FAIL'} (max_diff: {bwd_diff:.2e})") + + compile_ok = check_torch_compile(customop_model, batch_size) + print(f"torch.compile(fullgraph=True): {'✓ PASS' if compile_ok else '✗ FAIL'}") + + # ========================= + # opcheck + # ========================= + print("\n" + "-" * 70) + print("Op Registration Check") + print("-" * 70) + run_opcheck(customop_model) + + # ========================= + # Performance Benchmark + # ========================= + print("\n" + "-" * 70) + print("Performance Benchmark") + print("-" * 70) + + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Forward + print("\n>>> Forward Pass <<<") + ref_fwd = benchmark_forward(ref_model, 
x_list, num_iterations) + customop_fwd = benchmark_forward(customop_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") + print( + f"{'CustomOp (BMM)':<30} {customop_fwd:<12.4f} {ref_fwd/customop_fwd:<10.2f}x" + ) + + # Forward + Backward + print("\n>>> Forward + Backward <<<") + ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations) + customop_fwdbwd = benchmark_forward_backward(customop_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") + print( + f"{'CustomOp (BMM)':<30} {customop_fwdbwd:<12.4f} {ref_fwdbwd/customop_fwdbwd:<10.2f}x" + ) + + # ========================= + # SiLU*Up Kernel Bandwidth + # ========================= + print("\n" + "-" * 70) + print("Fused SiLU*Up Kernel Bandwidth") + print("-" * 70) + + shape = (batch_size, num_groups, hidden_dim) + fwd_ms, fwd_bw, bwd_ms, bwd_bw = benchmark_silu_mul_bandwidth(shape, dtype) + + print(f"\nShape: {shape}") + print(f"\n{'Kernel':<15} {'Time (ms)':<12} {'BW (GB/s)':<12}") + print("-" * 40) + print(f"{'Forward':<15} {fwd_ms:<12.4f} {fwd_bw:<12.1f}") + print(f"{'Backward':<15} {bwd_ms:<12.4f} {bwd_bw:<12.1f}") + + # ========================= + # Summary + # ========================= + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + print( + f""" + Forward Speedup: {ref_fwd/customop_fwd:.2f}x + Fwd+Bwd Speedup: {ref_fwdbwd/customop_fwdbwd:.2f}x + SiLU*Up BW: {fwd_bw:.1f} GB/s (fwd), {bwd_bw:.1f} GB/s (bwd) +""" + ) + print("=" * 70) + print("Done!") + + +if __name__ == "__main__": + main()
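+
+# Interpretation note: absolute timings and speedups depend on the GPU, dtype,
+# and problem shape. The single batched GEMM per projection mainly pays off
+# when num_groups is large enough that the per-group kernel launches of the
+# loop-based reference dominate its runtime.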