mm-kermac

Dynamically compiled hyper semirings for PyTorch using PTX Inject and Stack PTX

This repo provides semiring and semiring-gradient tensor operations for PyTorch. It also provides a DSL for writing your own custom semiring and semiring-gradient routines that accept hyperparameters passed into the kernel. A hyperparameter can be a single-value tensor, a single-value tensor broadcast across a batch of tensors, or a vector of batched hyperparameters applied to a batch of tensors.

Quickstart

import torch
import mm_kermac.hyper_semiring as kermac

device = torch.device("cuda")
M, N, K = 1024, 2048, 256
x = torch.randn(M, K, device=device)
z = torch.randn(N, K, device=device)
out = torch.empty((M, N), device=device)

# GEMM
gemm = kermac.Gemm()
# First call compiles and loads the module (cached per device).
gemm(x=x, z=z, out=out)

# L2 cdist
norm_l2 = kermac.NormL2()
norm_l2(x=x, z=z, out=out)

# Fractional p via hyperparameters
p = 1.3
norm_lp = kermac.NormLp(epsilon=0.0)
p_inner = torch.tensor(p, device=device)
p_outer = torch.tensor(1.0 / p, device=device)
norm_lp(x=x, z=z, p_inner=p_inner, p_outer=p_outer, out=out)
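
As a quick sanity check, you can compare against torch.cdist, which the benchmark below also uses as the reference (a host-side sketch, not part of the API; exact agreement depends on float32 accumulation order):

# `out` currently holds the NormLp result from the call above.
ref = torch.cdist(x, z, p=p)
print((out - ref).abs().max())  # should be small if the kernel agrees with cdist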

Installation

mm-kermac only supports NVIDIA cards with sm_80 or greater:

  • Server cards: A100-class or newer, e.g. A10, H100, B100, BH200
  • Consumer cards: RTX 3000 series or newer, e.g. 3070, 3090, 4090, 5090

To install, run the command that matches your CUDA toolkit version:

pip install mm-kermac[cu12]
pip install mm-kermac[cu13]

Zoo kernels

The zoo provides ready-to-use kernels built on HyperSemiringKernel and HyperSemiringGradientKernel.

import torch
import mm_kermac.hyper_semiring as hs

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

gemm = hs.Gemm()
norm_l1 = hs.NormL1()
norm_l2 = hs.NormL2()
norm_lp = hs.NormLp(epsilon=0.0)

gemm(x=x, z=z, out=out)
norm_l2(x=x, z=z, out=out, try_to_align=False)

For gradient kernels, see examples/hyper_semiring_gradient.py for the expected shapes and arguments.

Benchmarks

The benchmark compares NormL1, NormL2, and NormLp against torch.cdist for p=1.0, p=2.0, and a fractional p.

python examples/bench_hyper_semiring_cdist.py --M 2048 --N 2048 --K 256 --iters 50 --warmup 10 --p-frac 1.3

Sample output:

Device: NVIDIA GeForce RTX 5090
M=2048 N=2048 K=256 iters=50 warmup=10
Fractional p=1.3 epsilon=0.0 try_align=False
   case  |          kermac ms |          torch ms | speedup
   p=1.0 | kermac    0.078 ms | torch    4.779 ms |  61.03x
   p=2.0 | kermac    0.080 ms | torch    0.093 ms |   1.15x
   p=1.3 | kermac    0.375 ms | torch    5.312 ms |  14.16x

How it works

  • HyperSemiringKernel renders a CUTLASS/CuTe template, injects PTX stubs generated by Stack PTX, compiles to a cubin, and caches per device and signature.
  • The kernel is split into mma_lambda (per multiply-accumulate step) and epilogue_lambda (post-reduction).
  • hyper_dict maps user names to hyperparameter tensors. Insertion order maps to hyper0, hyper1, etc. in the generated PTX, and the lambdas receive a reg_dict keyed by those user names.
  • The number of hyperparameters is user-defined; the template is generated and cached separately for each count.
  • Hyperparameters may be scalar tensors or length-L tensors; L is inferred from the input batches and hyper tensors, so you can batch multiple p values in one call (see the sketch after this list).
  • Zoo kernels are thin wrappers that predefine the lambdas and build the right hyper_dict for you.
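
For example, a sketch of batching several p values in one NormLp call. The exact broadcasting rules and output shape here are assumptions; see the files under examples/ for the authoritative shapes:

import torch
import mm_kermac.hyper_semiring as kermac

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)

# Three p values evaluated in a single launch (L = 3).
ps = torch.tensor([1.1, 1.5, 1.9], device=device)
norm_lp = kermac.NormLp(epsilon=0.0)
out = torch.empty((ps.numel(), x.size(0), z.size(0)), device=device)  # assumed (L, M, N) output
norm_lp(x=x, z=z, p_inner=ps, p_outer=1.0 / ps, out=out)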

Custom kernel sketch:

import torch
from mm_kermac import PtxInstruction
from mm_kermac.hyper_semiring import HyperSemiringKernel

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

kernel = HyperSemiringKernel(
    mma_lambda=lambda a, b, c, reg: [
        a,  # push a
        b,  # push b
        PtxInstruction.sub_ftz_f32,  # diff = b - a
        reg["beta"],  # push beta (0.5, read at launch from hyper["beta"])
        PtxInstruction.mul_ftz_f32,  # diff *= beta
        c,  # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,  # push accumulator
        reg["gamma"],  # push gamma (2.0, read at launch from hyper["gamma"])
        PtxInstruction.mul_ftz_f32,  # scale output by gamma
    ],
)

hyper = {
    "beta": torch.tensor(0.5, device=device),
    "gamma": torch.tensor(2.0, device=device),
}
kernel(a=x, b=z, hyper_dict=hyper, out=out)

Custom kernel explanation

This example defines a semiring where the "multiply" is mul(a, b) = beta * (b - a) and the "add" is standard addition. The epilogue then scales the accumulated sum by gamma.

With beta=0.5 and gamma=2.0, the kernel computes:

out[m, n] = gamma * sum_k beta * (z[n, k] - x[m, k])

which simplifies to:

out[m, n] = sum_k (z[n, k] - x[m, k])
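
A quick reference check of this identity in plain PyTorch (a host-side sketch using the tensors from the block above, not part of the kernel API):

# out[m, n] = sum_k (z[n, k] - x[m, k]) = z.sum(1)[n] - x.sum(1)[m]
ref = z.sum(dim=1).unsqueeze(0) - x.sum(dim=1).unsqueeze(1)
print((out - ref).abs().max())  # should be small up to float32 accumulation differences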

HyperSemiringGradientKernel

Gradient kernels split the work into three stages:

  • multiply_lambda computes a per-element contribution from a, b, and d (often a derivative-like term).
  • accumulate_lambda combines that contribution with c and accumulates into e across k.
  • epilogue_lambda applies any final transform to e.

In the common pattern used by the zoo, the math looks like:

out[o, n, m] = epilogue( sum_k c[o, k] * multiply(d[n, m], b[n, k], a[k, m]) )

Example gradient kernel sketch:

import torch
from mm_kermac import PtxInstruction
from mm_kermac.hyper_semiring_gradient import HyperSemiringGradientKernel

device = torch.device("cuda")
grad_kernel_matrix = torch.randn(256, 128, device=device)  # a: (K, M)
x = torch.randn(512, 256, device=device)                   # b: (N, K)
coefs = torch.randn(64, 256, device=device)                # c: (O, K)
z = torch.randn(512, 128, device=device)                   # d: (N, M)
out = torch.empty((64, 512, 128), device=device)           # e: (O, N, M)

kernel = HyperSemiringGradientKernel(
    multiply_lambda=lambda d, b, a, reg: [
        b,  # push b
        d,  # push d
        PtxInstruction.sub_ftz_f32,  # diff = d - b
        reg["alpha"],  # push alpha (0.25 from hyper["alpha"])
        PtxInstruction.mul_ftz_f32,  # diff *= alpha
        a,  # push a
        PtxInstruction.mul_ftz_f32,  # diff *= a
    ],
    accumulate_lambda=lambda c, diff, e, reg: [
        c,  # push c
        diff,  # push diff
        PtxInstruction.mul_ftz_f32,  # c * diff
        e,  # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += c * diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,  # push accumulator
        reg["scale"],  # push scale (2.0 from hyper["scale"])
        PtxInstruction.mul_ftz_f32,  # scale output
    ],
)

hyper = {
    "alpha": torch.tensor(0.25, device=device),
    "scale": torch.tensor(2.0, device=device),
}
kernel(a=grad_kernel_matrix, b=x, c=coefs, d=z, hyper_dict=hyper, out=out)

With alpha=0.25 and scale=2.0, this computes:

out[o, n, m] = scale * sum_k c[o, k] * (alpha * (z[n, m] - x[n, k]) * a[k, m])
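
A reference check of this formula in plain PyTorch (a host-side sketch that materializes the full (N, K, M) difference tensor, which the kernel itself avoids):

alpha, scale = 0.25, 2.0
diff = z.unsqueeze(1) - x.unsqueeze(2)  # (N, K, M): diff[n, k, m] = z[n, m] - x[n, k]
ref = scale * alpha * torch.einsum('ok,nkm,km->onm', coefs, diff, grad_kernel_matrix)
print((out - ref).abs().max())  # should be small up to float32 accumulation differences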

See examples/hyper_semiring_gradient.py for concrete kernel definitions and expected shapes.

mm-ptx

This repo relies on mm-ptx for Stack PTX and PTX Inject. See that repo for details on how Stack PTX works, how to use it, and simplified examples of the system in use.

Tests

Tests are GPU-backed and require CUDA with sm_80 or greater. They are implemented as unittest copies of the examples.

Run all tests:

python -m unittest discover -s tests -p 'test_*.py' -v

If CUDA is unavailable or your GPU is below sm_80, the tests will be skipped.
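
A minimal sketch of such a skip guard (an illustration only; the actual tests may detect capability differently, and the class name here is hypothetical):

import unittest
import torch

def _cuda_sm80_available():
    # True only when CUDA is present and the device is sm_80 or newer.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 8

@unittest.skipUnless(_cuda_sm80_available(), "requires CUDA with sm_80 or greater")
class TestHyperSemiring(unittest.TestCase):
    ...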

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this software in your work, please cite it using the following BibTeX entry (generated from the CITATION.cff file):

@software{Durham_mm-kermac_2025,
  author       = {Durham, Charlie},
  title        = {mm-kermac: Dynamically compiled hyper semirings for Pytorch using PTX Inject and Stack PTX},
  version      = {1.0.0},
  date-released = {2025-10-19},
  url          = {https://github.com/MetaMachines/mm-kermac-py}
}
