63 commits
57b4d7b
[JAX] Remove import jax.extend.ffi (#2193)
phu0ngng Sep 22, 2025
5e4e0b2
[PyTorch] Add sink attention support from cuDNN (#2148)
cyanguwa Sep 22, 2025
2db20a6
[QA] Add pytest xml report for all tests in qa folder that use pytest…
shengfangd Sep 23, 2025
a92a0ad
[JAX] Local-Amax for Current-Scaling (#2183)
mingxu1067 Sep 23, 2025
3f875fb
[JAX] Restore Shardy Rule with CompoundFactor (#2167)
phu0ngng Sep 23, 2025
0c17c7e
JAX integration changes
vthumbe1503 Sep 24, 2025
90e070c
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 24, 2025
66c7086
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
af19dbf
revert line break
vthumbe1503 Sep 24, 2025
4f29915
revert line break
vthumbe1503 Sep 24, 2025
24828f3
missed adding oss swiglu to nvte enum in common
vthumbe1503 Sep 24, 2025
19410b6
fix jax linting errors
vthumbe1503 Sep 24, 2025
5480d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
7a917ea
fix jax linting errors
vthumbe1503 Sep 24, 2025
53dd179
revert multi_gpu_encoder change
vthumbe1503 Sep 24, 2025
afd15a1
[JAX] Update JAX version requirement in pyproject.toml (#2197)
phu0ngng Sep 24, 2025
d048807
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 25, 2025
3bfae54
fix flax integration bug
vthumbe1503 Sep 25, 2025
9c60c47
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Sep 25, 2025
38382dc
fix linting error
vthumbe1503 Sep 25, 2025
9e72796
[PyTorch] Unpin version of onnxscript and onnxruntime (#2202)
pggPL Sep 26, 2025
c7ef078
bug fixed in other branch and not here
vthumbe1503 Sep 26, 2025
c39ab8d
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 26, 2025
4d14578
[JAX] Fix XML filename in the L0_jax_uniitest (#2205)
phu0ngng Sep 27, 2025
d75bf43
[JAX] CollectiveGemm (#2166)
phu0ngng Sep 27, 2025
a91e458
[JAX] Add xml export for `test_multiprocessing_encoder` and `test_cge…
phu0ngng Sep 29, 2025
8446cc4
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 29, 2025
2a2e6de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
dfeef1a
[JAX] Address tolerance check for current scaling dact dbias (#2211)
jberchtold-nvidia Sep 29, 2025
3f5b475
[Core][PyTorch] NVFP4 recipe (#2177)
ksivaman Sep 29, 2025
b2f4fcb
bug in dbias computation
vthumbe1503 Sep 29, 2025
2354fb8
Fix the segfault in the nvfp4 quantization (#2214)
ptrendx Sep 30, 2025
25252e9
[PyTorch] Add FP8 attention with current scaling (#2012)
cyanguwa Sep 30, 2025
7fa0f55
[Pytorch] Support for Swiglu Activation used in GPT OSS (#2161)
vthumbe1503 Sep 30, 2025
ce18bee
[JAX] Load modules during initialize for Norm and Act primitives (#2219)
jberchtold-nvidia Sep 30, 2025
b7df6b6
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
4f41c1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
7022d50
[PyTorch] Quantizer as API (#2039)
negvet Oct 1, 2025
ac4e0fd
[JAX] Rework amax reduction over TPSP (#2218)
phu0ngng Oct 1, 2025
b0d562d
[JAX] Fix `rng_state` shape in fused attention (#2217)
phu0ngng Oct 1, 2025
ac886c3
[PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226)
negvet Oct 1, 2025
f0a9404
Fix hang during debug build (#2221)
ksivaman Oct 1, 2025
115e528
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
d2072b1
address review comments
vthumbe1503 Oct 1, 2025
6d9df80
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
978fcde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
13a3e3c
minor bug because of merge conflict
vthumbe1503 Oct 1, 2025
df0f449
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
d59526b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
4783514
accept copilot suggestion
vthumbe1503 Oct 1, 2025
cda0c82
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
90449f7
Convert `NVFP4BlockScaling` to dataclass (#2227)
ksivaman Oct 1, 2025
aee5a82
Fix the cuBLAS workspace alignment (#2223)
ptrendx Oct 1, 2025
14f8971
fix test and remove a redundant test addition
vthumbe1503 Oct 1, 2025
b9d7da7
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
c100318
[PyTorch] Set usages for linear op quantizers before forward (#2222)
timmoon10 Oct 2, 2025
f936c2a
[JAX] Fix code block in fp8_autocast docstring (#2228)
jberchtold-nvidia Oct 2, 2025
be7f43f
[JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229)
jberchtold-nvidia Oct 2, 2025
6b0c73c
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 2, 2025
e30c36a
[PyTorch] fix int32 overflow in permute kernels (#2196)
hxbai Oct 2, 2025
d462da1
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 2, 2025
5a55b0d
address review comments
vthumbe1503 Oct 3, 2025
bf3e04b
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 3, 2025
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
-pip install cmake==3.21.0 pybind11[global] ninja
+pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
-pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
+pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
-run: pip install pybind11[global]
+run: pip install pybind11[global] nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
-run: pip install torch pybind11[global] einops onnxscript
+run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
152 changes: 152 additions & 0 deletions benchmarks/benchmark_rht_cast.py
@@ -0,0 +1,152 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}


def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")

assert shape[0] % 16 == 0, "Shape must be divisible by 16"
assert shape[1] % 16 == 0, "Shape must be divisible by 16"

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}

timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6

input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems

total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reading input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)

throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)

print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps


# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()

if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")

shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]

if args.profile:
shapes = [
(16384, 6144),
]

data = []
for stochastic_rounding in [True]: # , False]:
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)

df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
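The memory-traffic model used for `total_nbytes` in `run_kernel` above can be checked standalone. The sketch below is a hedged reproduction of that byte accounting (bf16 input at 2 B/element read three times for the amax pass and both cast passes, packed FP4 outputs at 0.5 B/element, and one scale byte per 16 elements); the function names are illustrative, not part of the benchmark file.

```python
# Standalone sketch of the byte accounting in run_kernel above.
# Assumes bf16 input (2 B/elem), packed FP4 output (0.5 B/elem),
# and one scale-factor byte per 16 elements.

def nvfp4_cast_traffic_bytes(M: int, K: int) -> int:
    input_nbytes = M * K * 2    # bf16 input
    output_nbytes = M * K // 2  # packed fp4 output
    sf_nbytes = M * K // 16     # 1 scale byte per 16 elements
    return (
        input_nbytes * 3        # amax pass + two cast passes each read the input
        + 2 * 4                 # output: 2 floats for scale & amax
        + 2 * 4                 # input: 2 floats
        + output_nbytes * 2     # outputs from Cast(x) and Cast(RHT(x.T))
        + sf_nbytes * 2         # scale factors for both outputs
    )


def throughput_gbps(total_nbytes: int, timing_us: float) -> float:
    # Same conversion as the benchmark: bytes -> GiB, microseconds -> seconds.
    return total_nbytes / (1024**3) / (timing_us / 1e6)


print(nvfp4_cast_traffic_bytes(8192, 5120))  # 298844176 bytes for one benchmark shape
```

At a hypothetical 1000 us median timing, the (8192, 5120) shape would report roughly 278 GB/s under this model.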
1 change: 1 addition & 0 deletions build_tools/jax.py
@@ -87,4 +87,5 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
+libraries=["nccl"],
)
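The one-line `build_tools/jax.py` change above passes `libraries=["nccl"]` to the extension, which makes the build emit `-lnccl` at link time. A minimal hedged sketch of that mechanism, with illustrative module and source names rather than the real TE build configuration:

```python
# Sketch of what the build_tools/jax.py change does: the `libraries`
# argument of setuptools.Extension lists link-time libraries, emitted
# as -l<name> by the compiler driver. Names below are illustrative.
from setuptools import Extension

ext = Extension(
    name="transformer_engine_jax_sketch",  # hypothetical module name
    sources=["csrc/extensions.cpp"],       # placeholder source list
    extra_compile_args=["-O3"],
    libraries=["nccl"],                    # linked as -lnccl
)
print(ext.libraries)  # ['nccl']
```

Constructing the `Extension` only records the settings; the NCCL library is actually resolved when the extension is compiled and linked.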
2 changes: 1 addition & 1 deletion build_tools/pytorch.py
@@ -14,7 +14,7 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
-return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
+return ["torch>=2.1", "einops", "onnxscript", "onnx"]


def test_requirements() -> List[str]:
15 changes: 9 additions & 6 deletions build_tools/utils.py
@@ -234,15 +234,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:

@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
-    version = cuda_version()
-    if os.getenv("NVTE_CUDA_ARCHS") is None:
+    archs = os.getenv("NVTE_CUDA_ARCHS")
+    if archs is None:
+        version = cuda_version()
         if version >= (13, 0):
-            os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
+            archs = "75;80;89;90;100;100a;103a;120"
+        elif version >= (12, 9):
+            archs = "70;80;89;90;100;100a;103a;120"
         elif version >= (12, 8):
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
+            archs = "70;80;89;90;100;100a;120"
         else:
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
-    return os.getenv("NVTE_CUDA_ARCHS")
+            archs = "70;80;89;90"
+    return archs


def cuda_version() -> Tuple[int, ...]:
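The `build_tools/utils.py` change above stops writing the computed default back into `os.environ` and instead reads the `NVTE_CUDA_ARCHS` override once, falling back to a version-keyed default. A hedged standalone sketch of the corrected selection logic — here the CUDA version is passed as a parameter for self-containment, whereas the real function calls `cuda_version()`:

```python
import os

# Sketch of the corrected cuda_archs() selection from build_tools/utils.py:
# honor an NVTE_CUDA_ARCHS override if present, otherwise pick the default
# arch list for the detected CUDA version, without mutating os.environ.
def cuda_archs(version: tuple) -> str:
    archs = os.getenv("NVTE_CUDA_ARCHS")
    if archs is None:
        if version >= (13, 0):
            archs = "75;80;89;90;100;100a;103a;120"
        elif version >= (12, 9):
            archs = "70;80;89;90;100;100a;103a;120"
        elif version >= (12, 8):
            archs = "70;80;89;90;100;100a;120"
        else:
            archs = "70;80;89;90"
    return archs


print(cuda_archs((12, 8)))  # e.g. "70;80;89;90;100;100a;120" when the override is unset
```

Leaving `os.environ` untouched keeps the function a pure lookup: repeated calls (and the `lru_cache` on the real function) cannot observe a default that an earlier call silently installed as an "override".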