Dynamically compiled hyper semirings for PyTorch using PTX Inject and Stack PTX
This repo provides semiring and semiring-gradient tensor operations for PyTorch. It also provides a DSL for writing your own custom semiring and semiring-gradient routines that can take hyperparameters passed into the kernel. These hyperparameters can be single-value tensors, single-value tensors broadcast across a batch of tensors, or a vector of batched hyperparameters applied to a batch of tensors.
```python
import torch
import mm_kermac.hyper_semiring as kermac

device = torch.device("cuda")

M, N, K = 1024, 2048, 256
x = torch.randn(M, K, device=device)
z = torch.randn(N, K, device=device)
out = torch.empty((M, N), device=device)

# GEMM
gemm = kermac.Gemm()
# First call compiles and loads the module (cached per device).
gemm(x=x, z=z, out=out)

# L2 cdist
norm_l2 = kermac.NormL2()
norm_l2(x=x, z=z, out=out)

# Fractional p via hyperparameters
p = 1.3
norm_lp = kermac.NormLp(epsilon=0.0)
p_inner = torch.tensor(p, device=device)
p_outer = torch.tensor(1.0 / p, device=device)
norm_lp(x=x, z=z, p_inner=p_inner, p_outer=p_outer, out=out)
```

mm-kermac only supports NVIDIA cards with sm_80 or greater:
- For server cards, A100 or greater (e.g. A10, H100, B100, BH200)
- For consumer cards, RTX 3000 series or greater (e.g. 3070, 3090, 4090, 5090)
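You can check what the active device reports with plain PyTorch (nothing library-specific):

```python
import torch

# mm-kermac requires a device with compute capability sm_80 or greater.
major, minor = torch.cuda.get_device_capability()
print(f"sm_{major}{minor}")
```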
To install, run one of the following depending on your CUDA toolkit version:

```bash
pip install mm-kermac[cu12]
pip install mm-kermac[cu13]
```

The zoo provides ready-to-use kernels built on `HyperSemiringKernel` and `HyperSemiringGradientKernel`.
```python
import torch
import mm_kermac.hyper_semiring as hs

device = torch.device("cuda")

x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

gemm = hs.Gemm()
norm_l1 = hs.NormL1()
norm_l2 = hs.NormL2()
norm_lp = hs.NormLp(epsilon=0.0)

gemm(x=x, z=z, out=out)
norm_l2(x=x, z=z, out=out, try_to_align=False)
```

For gradient kernels, see `examples/hyper_semiring_gradient.py` for the expected shapes and arguments.
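Continuing from the zoo snippet above, a rough sanity check against `torch.cdist` looks like this. It assumes `NormL1` and `NormL2` return the same (non-squared) p=1 and p=2 distances the benchmark below compares against, and uses a loose tolerance since the kernels use flush-to-zero float32 math:

```python
# Assumed: NormL2 computes Euclidean (p=2) and NormL1 Manhattan (p=1) distances.
norm_l2(x=x, z=z, out=out)
torch.testing.assert_close(out, torch.cdist(x, z, p=2.0), rtol=1e-2, atol=1e-2)

norm_l1(x=x, z=z, out=out)
torch.testing.assert_close(out, torch.cdist(x, z, p=1.0), rtol=1e-2, atol=1e-2)
```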
The benchmark compares `NormL1`, `NormL2`, and `NormLp` against `torch.cdist` for p=1.0, p=2.0, and a fractional p.
```bash
python examples/bench_hyper_semiring_cdist.py --M 2048 --N 2048 --K 256 --iters 50 --warmup 10 --p-frac 1.3
```

Sample output:

```text
Device: NVIDIA GeForce RTX 5090
M=2048 N=2048 K=256 iters=50 warmup=10
Fractional p=1.3 epsilon=0.0 try_align=False
case  | kermac ms | torch ms | speedup
p=1.0 | kermac 0.078 ms | torch 4.779 ms | 61.03x
p=2.0 | kermac 0.080 ms | torch 0.093 ms | 1.15x
p=1.3 | kermac 0.375 ms | torch 5.312 ms | 14.16x
```
`HyperSemiringKernel` renders a CUTLASS/CuTe template, injects PTX stubs generated by Stack PTX, compiles to a cubin, and caches per device and signature.

- The kernel is split into `mma_lambda` (per multiply-accumulate step) and `epilogue_lambda` (post-reduction).
- `hyper_dict` maps user names to hyperparameter tensors. Insertion order maps to `hyper0`, `hyper1`, etc. in the generated PTX, and the lambdas receive a `reg_dict` keyed by those user names.
- The number of hyperparameters is user defined; the template is generated and cached separately for each count.
- Hyperparameters may be scalar tensors or length-L tensors; L is inferred from input batches and hyper tensors, so you can batch multiple p values in one call (see the sketch after this list).
- Zoo kernels are thin wrappers that predefine the lambdas and build the right `hyper_dict` for you.
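For example, a sketch of batching several p values in one call. Passing length-L tensors directly as `p_inner`/`p_outer` and the batched `out` layout of `(L, M, N)` are assumptions here; check the `examples/` directory for the exact convention:

```python
import torch
import mm_kermac.hyper_semiring as hs

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)

# Three p values in one launch; L = 3 would be inferred from the hyper tensors.
ps = torch.tensor([1.1, 1.3, 1.7], device=device)

norm_lp = hs.NormLp(epsilon=0.0)
# Assumed layout: one (M, N) distance matrix per batched p value.
out = torch.empty((ps.numel(), x.size(0), z.size(0)), device=device)
norm_lp(x=x, z=z, p_inner=ps, p_outer=1.0 / ps, out=out)
```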
Custom kernel sketch:
```python
import torch
from mm_kermac import PtxInstruction
from mm_kermac.hyper_semiring import HyperSemiringKernel

device = torch.device("cuda")

x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

kernel = HyperSemiringKernel(
    mma_lambda=lambda a, b, c, reg: [
        a,                           # push a
        b,                           # push b
        PtxInstruction.sub_ftz_f32,  # diff = b - a
        reg["beta"],                 # push beta (0.5 at runtime, from hyper["beta"])
        PtxInstruction.mul_ftz_f32,  # diff *= beta
        c,                           # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,                           # push accumulator
        reg["gamma"],                # push gamma (2.0 at runtime, from hyper["gamma"])
        PtxInstruction.mul_ftz_f32,  # scale output by gamma
    ],
)

hyper = {
    "beta": torch.tensor(0.5, device=device),
    "gamma": torch.tensor(2.0, device=device),
}

kernel(a=x, b=z, hyper_dict=hyper, out=out)
```

This example defines a semiring where the "multiply" is mul(a, b) = beta * (b - a) and the "add" is standard addition. The epilogue then scales the accumulated sum by gamma.
With beta=0.5 and gamma=2.0, the kernel computes:

```text
out[m, n] = gamma * sum_k beta * (z[n, k] - x[m, k])
```

which simplifies to:

```text
out[m, n] = sum_k (z[n, k] - x[m, k])
```
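Continuing from the sketch above, the result can be checked against plain PyTorch. Since the sum over k separates into row sums of z and x, no (M, N, K) intermediate is needed; the tolerance is loose because the kernel uses flush-to-zero float32 arithmetic:

```python
# Reference: gamma * beta * sum_k (z[n, k] - x[m, k]) reduces to row sums of z and x.
beta, gamma = 0.5, 2.0
ref = gamma * beta * (z.sum(dim=1).unsqueeze(0) - x.sum(dim=1).unsqueeze(1))  # (M, N)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```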
Gradient kernels split the work into three stages:
- `multiply_lambda` computes a per-element contribution from `a`, `b`, and `d` (often a derivative-like term).
- `accumulate_lambda` combines that contribution with `c` and accumulates into `e` across `k`.
- `epilogue_lambda` applies any final transform to `e`.
In the common pattern used by the zoo, the math looks like:
```text
out[o, n, m] = epilogue( sum_k c[o, k] * multiply(d[n, m], b[n, k], a[k, m]) )
```
Example gradient kernel sketch:
```python
import torch
from mm_kermac import Stack, PtxInstruction
from mm_kermac.hyper_semiring_gradient import HyperSemiringGradientKernel

device = torch.device("cuda")

grad_kernel_matrix = torch.randn(256, 128, device=device)  # a: (K, M)
x = torch.randn(512, 256, device=device)                   # b: (N, K)
coefs = torch.randn(64, 256, device=device)                # c: (O, K)
z = torch.randn(512, 128, device=device)                   # d: (N, M)
out = torch.empty((64, 512, 128), device=device)           # e: (O, N, M)

kernel = HyperSemiringGradientKernel(
    multiply_lambda=lambda d, b, a, reg: [
        b,                           # push b
        d,                           # push d
        PtxInstruction.sub_ftz_f32,  # diff = d - b
        reg["alpha"],                # push alpha (0.25 at runtime, from hyper["alpha"])
        PtxInstruction.mul_ftz_f32,  # diff *= alpha
        a,                           # push a
        PtxInstruction.mul_ftz_f32,  # diff *= a
    ],
    accumulate_lambda=lambda c, diff, e, reg: [
        c,                           # push c
        diff,                        # push diff
        PtxInstruction.mul_ftz_f32,  # c * diff
        e,                           # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += c * diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,                           # push accumulator
        reg["scale"],                # push scale (2.0 at runtime, from hyper["scale"])
        PtxInstruction.mul_ftz_f32,  # scale output
    ],
)

hyper = {
    "alpha": torch.tensor(0.25, device=device),
    "scale": torch.tensor(2.0, device=device),
}

kernel(a=grad_kernel_matrix, b=x, c=coefs, d=z, hyper_dict=hyper, out=out)
```

With alpha=0.25 and scale=2.0, this computes:
```text
out[o, n, m] = scale * sum_k c[o, k] * (alpha * (z[n, m] - x[n, k]) * a[k, m])
```
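Again as a hedged check against plain PyTorch, using only standard einsum and broadcasting on the tensors from the sketch above (loose tolerance for the flush-to-zero float32 math):

```python
# Reference: expand sum_k c[o, k] * (z[n, m] - x[n, k]) * a[k, m] into two contractions.
alpha, scale = 0.25, 2.0
ref = scale * alpha * (
    z.unsqueeze(0) * (coefs @ grad_kernel_matrix).unsqueeze(1)     # z[n, m] * sum_k c[o, k] * a[k, m]
    - torch.einsum("ok,nk,km->onm", coefs, x, grad_kernel_matrix)  # sum_k c[o, k] * x[n, k] * a[k, m]
)  # (O, N, M)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```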
See `examples/hyper_semiring_gradient.py` for concrete kernel definitions and expected shapes.
This repo relies on mm-ptx for Stack PTX and PTX Inject. See the mm-ptx repo for details on how Stack PTX works, how to use it, and for simplified examples of the system.
Tests are GPU-backed and require CUDA with sm_80 or greater. They are implemented as unittest copies of the examples.
Run all tests:
```bash
python -m unittest discover -s tests -p 'test_*.py' -v
```

If CUDA is unavailable or your GPU is below sm_80, the tests will be skipped.
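The skip behavior amounts to a guard along these lines (a sketch of the pattern, not the tests' literal code):

```python
import unittest

import torch

def _has_sm80_or_greater() -> bool:
    # Skip when there is no CUDA device or its compute capability is below sm_80.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

@unittest.skipUnless(_has_sm80_or_greater(), "requires CUDA with sm_80 or greater")
class HyperSemiringTestCase(unittest.TestCase):
    def test_smoke(self):
        self.assertTrue(torch.cuda.is_available())
```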
This project is licensed under the MIT License - see the LICENSE file for details.
If you use this software in your work, please cite it using the following BibTeX entry (generated from the CITATION.cff file):
```bibtex
@software{Durham_mm-kermac_2025,
  author = {Durham, Charlie},
  title = {mm-kermac: Dynamically compiled hyper semirings for Pytorch using PTX Inject and Stack PTX},
  version = {1.0.0},
  date-released = {2025-10-19},
  url = {https://github.com/MetaMachines/mm-kermac-py}
}
```