82 commits
3b4366b
Fix CI failures for UB overlap changes (#2149)
djns99 Sep 3, 2025
f378eaf
[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100…
KshitijLakhani Sep 3, 2025
0f68f7b
[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Gr…
zhongbozhu Sep 4, 2025
e9a5fa4
[PyTorch] fix cross entropy vanishing gradients (#2139)
casper-hansen Sep 4, 2025
11e9d66
Fix bug when enabling --overlap-grad-reduce in mcore (#2142)
lhb8125 Sep 5, 2025
b10f436
Fix CUDA version in setup.py (#2132)
vcherepanov-nv Sep 5, 2025
c47f329
[JAX] NoScaleTensor wrapper for non-quantized data (#2136)
jberchtold-nvidia Sep 5, 2025
5b3d65c
[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)
phu0ngng Sep 8, 2025
aa06107
Fixing few issues with multi-process launching. (#2155)
mingxu1067 Sep 8, 2025
603dbf7
Update list of authorized CI users (#2152)
timmoon10 Sep 8, 2025
84fa28d
Fused RoPE with combined QKV input. (#2122)
vasunvidia Sep 8, 2025
a26a7f1
Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162)
Autumn1998 Sep 9, 2025
5f2b831
[JAX] Scale swizzling via JAX transpose op (#2163)
phu0ngng Sep 9, 2025
4903f94
Extract cpp distributed tests into a separate project (#2165)
vcherepanov-nv Sep 10, 2025
483d959
Adds context parallelism utilities: moving cp shards to diff ranks an…
jomitchellnv Sep 10, 2025
405d474
[PyTorch Debug] Fix issue with negative underflow% stat. (#2107)
pggPL Sep 15, 2025
cd2034f
Lower precision gated-act to accelerate FP8 current-scaling. (#2153)
mingxu1067 Sep 15, 2025
59130cc
[PyTorch] Support activation CPU offloading in fusible ops (#2158)
timmoon10 Sep 15, 2025
258d084
Do not use normalization forward + amax fusion if cuDNN backend is re…
janekb04 Sep 16, 2025
c221909
Fix unjoined comm stream in UB communicator (#2160)
djns99 Sep 16, 2025
ba37529
FP8 Output Quantization for GEMM (#2123)
vthumbe1503 Sep 17, 2025
7042d7a
TE Gemma tutorial attempt#2 (#1839)
sudhakarsingh27 Sep 17, 2025
93a67af
Fix memory overhead of linear layer when all gather from sequence par…
yuzhongw-nvidia Sep 17, 2025
eb69fad
Fix incorrect TP rank calculation when using data parallel (#2179)
djns99 Sep 17, 2025
8aee1bb
[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model…
cassiewilliam Sep 18, 2025
c334fc4
[PyTorch] Support FA3 for MLA and with CP (#1907)
zhujian19891203 Sep 18, 2025
7f77127
Fix cuDNN version checks when getting backend and for sm89 kv cache (…
KshitijLakhani Sep 18, 2025
5b3092a
Changed VERSION to 2.9.0.dev0
ptrendx Sep 19, 2025
57b4d7b
[JAX] Remove import jax.extend.ffi (#2193)
phu0ngng Sep 22, 2025
5e4e0b2
[PyTorch] Add sink attention support from cuDNN (#2148)
cyanguwa Sep 22, 2025
2db20a6
[QA] Add pytest xml report for all tests in qa folder that use pytest…
shengfangd Sep 23, 2025
a92a0ad
[JAX] Local-Amax for Current-Scaling (#2183)
mingxu1067 Sep 23, 2025
3f875fb
[JAX] Restore Shardy Rule with CompoundFactor (#2167)
phu0ngng Sep 23, 2025
afd15a1
[JAX] Update JAX version requirement in pyproject.toml (#2197)
phu0ngng Sep 24, 2025
972a842
Merge branch 'main' into hongbinl/adapt_for_offload_activation
Sep 26, 2025
7933781
temp fix to enable --overlap-grad-reduce
Sep 26, 2025
9e72796
[PyTorch] Unpin version of onnxscript and onnxruntime (#2202)
pggPL Sep 26, 2025
4d14578
[JAX] Fix XML filename in the L0_jax_uniitest (#2205)
phu0ngng Sep 27, 2025
d75bf43
[JAX] CollectiveGemm (#2166)
phu0ngng Sep 27, 2025
963b39c
fix to enable --overlap-grad-reduce
lhb8125 Sep 29, 2025
a91e458
[JAX] Add xml export for `test_multiprocessing_encoder` and `test_cge…
phu0ngng Sep 29, 2025
dfeef1a
[JAX] Address tolerance check for current scaling dact dbias (#2211)
jberchtold-nvidia Sep 29, 2025
3f5b475
[Core][PyTorch] NVFP4 recipe (#2177)
ksivaman Sep 29, 2025
2354fb8
Fix the segfault in the nvfp4 quantization (#2214)
ptrendx Sep 30, 2025
25252e9
[PyTorch] Add FP8 attention with current scaling (#2012)
cyanguwa Sep 30, 2025
7fa0f55
[Pytorch] Support for Swiglu Activation used in GPT OSS (#2161)
vthumbe1503 Sep 30, 2025
ce18bee
[JAX] Load modules during initialize for Norm and Act primitives (#2219)
jberchtold-nvidia Sep 30, 2025
7022d50
[PyTorch] Quantizer as API (#2039)
negvet Oct 1, 2025
ac4e0fd
[JAX] Rework amax reduction over TPSP (#2218)
phu0ngng Oct 1, 2025
b0d562d
[JAX] Fix `rng_state` shape in fused attention (#2217)
phu0ngng Oct 1, 2025
ac886c3
[PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226)
negvet Oct 1, 2025
f0a9404
Fix hang during debug build (#2221)
ksivaman Oct 1, 2025
90449f7
Convert `NVFP4BlockScaling` to dataclass (#2227)
ksivaman Oct 1, 2025
aee5a82
Fix the cuBLAS workspace alignment (#2223)
ptrendx Oct 1, 2025
c100318
[PyTorch] Set usages for linear op quantizers before forward (#2222)
timmoon10 Oct 2, 2025
f936c2a
[JAX] Fix code block in fp8_autocast docstring (#2228)
jberchtold-nvidia Oct 2, 2025
be7f43f
[JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229)
jberchtold-nvidia Oct 2, 2025
e30c36a
[PyTorch] fix int32 overflow in permute kernels (#2196)
hxbai Oct 2, 2025
b840898
[JAX] Clamped Swiglu Integration (#2194)
vthumbe1503 Oct 3, 2025
dfe5b7d
[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek…
janekb04 Oct 3, 2025
5be8125
Fix bug where CUTLASS kernel was not being compiled for SM90a (#2235)
timmoon10 Oct 4, 2025
08779fd
Fix FP8 current scaling attention logic (#2234)
ksivaman Oct 4, 2025
7e45be7
Added the NVFP4 section to the low precision training tutorial (#2237)
ptrendx Oct 5, 2025
0db0f4d
[JAX] Fix for GEMM + fuse bias + AllReduce (#2230)
phu0ngng Oct 6, 2025
56e2fed
[Build] fix: TE installation failed to find uv-installed cuDNN librar…
KivenChen Oct 6, 2025
9f3e79b
[PyTorch] Fix tests for 🤗 integration (#2239)
ksivaman Oct 6, 2025
127b6d3
[JAX] Activation/Normalization to output amax for later quantization …
phu0ngng Oct 7, 2025
76bced5
`NVFP4BlockScaling` recipe docs (#2241)
ksivaman Oct 7, 2025
ac5e868
Skip fp8 tests on unsupported devices (#2243)
vcherepanov-nv Oct 7, 2025
66f9b3c
[PyTorch] Unblock fused bgrad quantization path for nvfp4 (#2246)
ksivaman Oct 8, 2025
af2a0c1
[JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2…
huanghua1994 Oct 8, 2025
e37e33e
Disallow pure E5M2 recipe for `Float8BlockScaling` (#2251)
ksivaman Oct 9, 2025
9bf4175
[PyTorch] Deprecate old `float8_tensor.py` (#2250)
ksivaman Oct 9, 2025
e99be1b
Update minimum python version to 3.10 and add checks in CI (#2247)
ksivaman Oct 9, 2025
8a7ab3d
[JAX] NVFP4 support in TE/JAX (#2254)
jberchtold-nvidia Oct 9, 2025
dd9433e
Don't pickle an empty dict in LayerNorm and pt base modules (#2253)
pstjohn Oct 9, 2025
663fc8e
Merge branch 'main' into hongbinl/adapt_for_offload_activation
lhb8125 Oct 11, 2025
e9f49f4
Merge branch 'hongbinl/adapt_for_offload_activation' of https://githu…
lhb8125 Oct 11, 2025
98d354c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2025
88c7b05
add comments
lhb8125 Oct 11, 2025
665bcd1
Merge branch 'hongbinl/adapt_for_offload_activation' of https://githu…
lhb8125 Oct 11, 2025
fe9bab4
remove unused code
lhb8125 Oct 21, 2025
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global]
run: pip install pybind11[global] nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -57,6 +57,7 @@ jobs:
|| github.actor == 'tdophung'
|| github.actor == 'vthumbe1503'
|| github.actor == 'janekb04'
|| github.actor == 'shengfangd'
)
steps:
- name: Check if comment is issued by authorized person
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -38,3 +38,9 @@ repos:
entry: clang-format -i
args: ["-style=file"]
files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$

- repo: https://github.com/netromdk/vermin
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
1 change: 1 addition & 0 deletions 3rdparty/cutlass
Submodule cutlass added at 57e3cf
152 changes: 152 additions & 0 deletions benchmarks/benchmark_rht_cast.py
@@ -0,0 +1,152 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}


def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")

assert shape[0] % 16 == 0, "Shape must be divisible by 16"
assert shape[1] % 16 == 0, "Shape must be divisible by 16"

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}

timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6

input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems

total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reading input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)

throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)

print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps


# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()

if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")

shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]

if args.profile:
shapes = [
(16384, 6144),
]

data = []
for stochastic_rounding in [True]: # , False]:
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)

df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
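The throughput figure printed by the benchmark above comes from a byte-traffic model: the bf16 input is read three times (amax pass, Cast(x), Cast(RHT(x.T))), each fp4 output is half a byte per element, and one scale-factor byte covers 16 elements. A minimal standalone sketch of that accounting (helper names here are illustrative, not part of the benchmark):

```python
def nvfp4_cast_traffic_bytes(M: int, K: int) -> int:
    """Estimate bytes moved by the fused RHT + NVFP4 cast, mirroring
    the accounting in benchmark_rht_cast.py."""
    input_nbytes = M * K * 2    # bf16 input, 2 bytes per element
    output_nbytes = M * K // 2  # fp4 output, 2 elements per byte
    sf_nbytes = M * K // 16     # 1 scale-factor byte per 16 elements
    return (
        input_nbytes * 3        # input read for amax, Cast(x), Cast(RHT(x.T))
        + 2 * 4                 # scale & amax outputs (two float32 scalars)
        + 2 * 4                 # scale & amax inputs (two float32 scalars)
        + output_nbytes * 2     # rowwise and columnwise fp4 outputs
        + sf_nbytes * 2         # scale factors for both outputs
    )


def throughput_gbps(total_bytes: int, timing_us: float) -> float:
    """Convert bytes moved and kernel time in microseconds to GiB/s."""
    return total_bytes / (1024**3) / (timing_us / 1e6)
```

For the (8192, 5120) shape in the benchmark's list, the model counts roughly 0.28 GiB of traffic per kernel invocation.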
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.8.0.dev0
2.9.0.dev0
1 change: 1 addition & 0 deletions build_tools/build_ext.py
@@ -57,6 +57,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
build_dir,
f"-DPython_EXECUTABLE={sys.executable}",
f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
f"-DPython_SITEARCH={sysconfig.get_path('platlib')}",
f"-DCMAKE_BUILD_TYPE={build_type}",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
1 change: 1 addition & 0 deletions build_tools/jax.py
@@ -87,4 +87,5 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
)
4 changes: 2 additions & 2 deletions build_tools/pytorch.py
@@ -14,12 +14,12 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
return ["torch>=2.1", "einops", "onnxscript", "onnx"]


def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "torchvision", "transformers"]
return ["numpy", "torchvision", "transformers", "torchao==0.13"]


def setup_pytorch_extension(
34 changes: 28 additions & 6 deletions build_tools/utils.py
@@ -12,12 +12,31 @@
import shutil
import subprocess
import sys
import platform
from pathlib import Path
from importlib.metadata import version as get_version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union


# Needs to stay consistent with .pre-commit-config.yaml config.
def min_python_version() -> Tuple[int]:
"""Minimum supported Python version."""
return (3, 10, 0)


def min_python_version_str() -> str:
"""String representing minimum supported Python version."""
return ".".join(map(str, min_python_version()))


if sys.version_info < min_python_version():
raise RuntimeError(
f"Transformer Engine requires Python {min_python_version_str()} or newer, "
f"but found Python {platform.python_version()}."
)


@functools.lru_cache(maxsize=None)
def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration"""
@@ -234,15 +253,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:

@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
version = cuda_version()
if os.getenv("NVTE_CUDA_ARCHS") is None:
archs = os.getenv("NVTE_CUDA_ARCHS")
if archs is None:
version = cuda_version()
if version >= (13, 0):
os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
archs = "75;80;89;90;100;100a;103a;120"
elif version >= (12, 9):
archs = "70;80;89;90;100;100a;103a;120"
elif version >= (12, 8):
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
archs = "70;80;89;90;100;100a;120"
else:
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
return os.getenv("NVTE_CUDA_ARCHS")
archs = "70;80;89;90"
return archs


def cuda_version() -> Tuple[int, ...]:
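The `cuda_archs()` change above replaces the old pattern of writing the default arch list back into `os.environ` with a single read of `NVTE_CUDA_ARCHS` followed by a plain return. A standalone sketch of the same fallback logic — taking the CUDA version and environment as parameters instead of detecting them, purely for illustration:

```python
def select_cuda_archs(cuda_version: tuple, env: dict) -> str:
    """Pick the CUDA arch list, mirroring the patched build_tools/utils.py:
    an explicit NVTE_CUDA_ARCHS wins; otherwise fall back by toolkit version,
    without mutating the environment."""
    archs = env.get("NVTE_CUDA_ARCHS")
    if archs is None:
        if cuda_version >= (13, 0):
            archs = "75;80;89;90;100;100a;103a;120"
        elif cuda_version >= (12, 9):
            archs = "70;80;89;90;100;100a;103a;120"
        elif cuda_version >= (12, 8):
            archs = "70;80;89;90;100;100a;120"
        else:
            archs = "70;80;89;90"
    return archs
```

Note that only the unset-variable branch consults the toolkit version; an explicit `NVTE_CUDA_ARCHS` always takes precedence, and nothing is written back, so repeated calls stay side-effect free.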
2 changes: 2 additions & 0 deletions docs/api/common.rst
@@ -12,6 +12,8 @@ Common API

.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)

.. autoapiclass:: transformer_engine.common.recipe.NVFP4BlockScaling(fp4_format=Format.E2M1)

.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID)

.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3)
Binary file added docs/examples/FP4_format.png
Binary file added docs/examples/FP4_linear.png
2 changes: 1 addition & 1 deletion docs/examples/attention/attention.ipynb
@@ -390,7 +390,7 @@
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",