From fa68781cba1f5673b01a9024f44c5b9b8c0ba00e Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Tue, 17 Feb 2026 08:51:35 -0800
Subject: [PATCH 1/2] Fix `build_tools` missing from sdist causing `uv` cached installs to fail (#2684)

- Include `build_tools/` in the source distribution via `MANIFEST.in` so that
  cached builds from `uv` (and `pip`) can resolve `setup.py`'s top-level imports

`setup.py` imports from `build_tools` at the top level:

```python
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version
from build_tools.utils import cuda_archs, cuda_version, ...
```

The `__legacy__` build backend in `pyproject.toml` adds the source root to
`sys.path`, so these imports work when building directly from the source tree.
However, `build_tools/` is not included in the sdist because:

1. `MANIFEST.in` did not list it
2. `build_tools/` is not discovered by `find_packages()` (it's a standalone
   directory at the repo root, not under `transformer_engine/`)

When `uv` caches the sdist and later builds a wheel from it, the sdist is
extracted to a temporary directory where `build_tools/` is absent, causing a
`ModuleNotFoundError`. Passing `--no-cache` to `uv` works around this by
forcing a fresh build from the full source tree.

Added `build_tools` to `MANIFEST.in`:

```diff
 recursive-include transformer_engine/common/include *.*
+recursive-include build_tools *.py *.txt
```

- [x] `python setup.py sdist` produces a tarball that contains `build_tools/`

```
$ tar tzf dist/transformer_engine-*.tar.gz | grep build_tools
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/VERSION.txt
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/__init__.py
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/build_ext.py
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/jax.py
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/pytorch.py
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/te_version.py
transformer_engine-2.13.0.dev0+82f7ebeb/build_tools/utils.py
```

Signed-off-by: Hemil Desai
Co-authored-by: Claude Opus 4.6
---
 MANIFEST.in | 1 +
 1 file changed, 1 insertion(+)

diff --git a/MANIFEST.in b/MANIFEST.in
index c34025772a..c2309a0370 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1,2 @@
 recursive-include transformer_engine/common/include *.*
+recursive-include build_tools *.py *.txt


From 7e48fa1bace10749e2eabe1da4bbe045bacc005d Mon Sep 17 00:00:00 2001
From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Date: Tue, 17 Feb 2026 09:04:04 -0800
Subject: [PATCH 2/2] [JAX] Debugging inspect utility (#2651)

* initial debug of inspect ffi

Signed-off-by: Jeremy Berchtold

* writing binary dumps of tensors works

Signed-off-by: Jeremy Berchtold

* loading works

Signed-off-by: Jeremy Berchtold

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

Signed-off-by: Jeremy Berchtold

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add tensor statistics

Signed-off-by: Jeremy Berchtold

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

Signed-off-by: Jeremy Berchtold

* Add cuda error check and tests

Signed-off-by: Jeremy Berchtold

* Add __init__.py to debug folder

Signed-off-by: Jeremy Berchtold

* Fix lint

Signed-off-by: Jeremy Berchtold

* Fix lint

Signed-off-by: Jeremy Berchtold
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address greptile comments

Signed-off-by: Jeremy Berchtold

* Lint

Signed-off-by: Jeremy Berchtold

* Gate tests behind fp8 support

Signed-off-by: Jeremy Berchtold

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jeremy Berchtold
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 tests/jax/test_custom_call_compute.py    |  34 ++++
 transformer_engine/jax/csrc/extensions.h |   3 +
 .../jax/csrc/extensions/amax.cpp          |   2 -
 .../jax/csrc/extensions/inspect.cpp       |  99 ++++++++++
 .../jax/csrc/extensions/pybind.cpp        |   3 +
 transformer_engine/jax/debug/__init__.py  |  11 ++
 .../jax/debug/experimental/__init__.py    |  14 ++
 .../jax/debug/experimental/inspect.py     | 174 ++++++++++++++++++
 8 files changed, 338 insertions(+), 2 deletions(-)
 create mode 100644 transformer_engine/jax/csrc/extensions/inspect.cpp
 create mode 100644 transformer_engine/jax/debug/__init__.py
 create mode 100644 transformer_engine/jax/debug/experimental/__init__.py
 create mode 100644 transformer_engine/jax/debug/experimental/inspect.py

diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 80fcc68843..613aefc178 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -1921,3 +1921,37 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
         assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
         assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
         assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
+
+
+class TestDebugInspectFFI:
+
+    @pytest_parametrize_wrapper("shape", [(256, 128)])
+    @pytest_parametrize_wrapper(
+        "dtype",
+        [
+            jnp.float32,
+            jnp.bfloat16,
+            jnp.float16,
+            # Note: fp4 currently doesn't work
+            # jnp.float4_e2m1fn
+        ]
+        + ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
+    )
+    def test_debug_inspect_ffi(self, shape, dtype):
+        from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump
+
+        def f(x):
+            x = x + 1
+            x = inspect_array(x, "my_array")
+            x = x + 1
+            return x
+
+        key = jax.random.PRNGKey(0)
+        x = jax.random.uniform(key, shape, jnp.float32)
+        x = x.astype(dtype)
+        _ = jax.jit(f)(x)
+
+        expected = x + 1
+        actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
+
+        assert_allclose(actual, expected, dtype=dtype)
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index 3fd086e257..1c0bc52b88 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
 
+// Inspect
+XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);
+
 // Cudnn helpers
 XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
 
diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp
index 5ffccaffb4..58c89cfd32 100644
--- a/transformer_engine/jax/csrc/extensions/amax.cpp
+++ b/transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,8 +5,6 @@
  ************************************************************************/
 #include
 
-#include
-
 #include "../extensions.h"
 #include "transformer_engine/cast.h"
 #include "transformer_engine/hadamard_transform.h"
diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp
new file mode 100644
index 0000000000..9012cd054c
--- /dev/null
+++ b/transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -0,0 +1,99 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+#include <cuda_runtime.h>
+
+#include <fstream>
+#include <string>
+
+#include "../extensions.h"
+#include "xla/ffi/api/c_api.h"
+
+namespace transformer_engine {
+namespace jax {
+
+Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
+                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
+                      Result_Type output_buf) {
+  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
+  NVTE_CHECK(output_buf->untyped_data() != nullptr,
+             "Output must be provided for inspect operation");
+  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
+             "Input and output must point to the same buffer for inspect operation");
+
+  std::vector<uint8_t> input_data(input_buf.size_bytes());
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
+                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));
+
+  float min_val{}, max_val{}, mean_val{}, std_val{};
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
+                                  cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
+                                  cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
+                                  cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
+                                  cudaMemcpyDeviceToHost, stream));
+
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+  int device;
+  NVTE_CHECK_CUDA(cudaGetDevice(&device));
+
+  // Write the tensor data to a file as a binary blob
+  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
+  std::ofstream file(filename, std::ios::binary);
+  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
+  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
+  file.close();
+
+  // Write out a metadata file
+  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
+  std::ofstream meta_file(meta_filename);
+  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
+  meta_file << "{";
+  meta_file << "\"shape\": [";
+  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+    meta_file << input_buf.dimensions()[i];
+    if (i < input_buf.dimensions().size() - 1) {
+      meta_file << ", ";
+    }
+  }
+  meta_file << "], ";
+  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
+  meta_file << ", \"min\": " << min_val;
+  meta_file << ", \"max\": " << max_val;
+  meta_file << ", \"mean\": " << mean_val;
+  meta_file << ", \"std\": " << std_val;
+  meta_file << "}";
+  meta_file.close();
+
+  // Log the tensor metadata to the console
+  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
+  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
+    if (i < input_buf.dimensions().size() - 1) {
+      printf(", ");
+    }
+  }
+  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
+  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // input
+                                  .Arg<Buffer_Type>()      // min
+                                  .Arg<Buffer_Type>()      // max
+                                  .Arg<Buffer_Type>()      // mean
+                                  .Arg<Buffer_Type>()      // std
+                                  .Ret<Buffer_Type>()      // output
+);
+
+}  // namespace jax
+}  // namespace transformer_engine
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index bd4b8fe2c2..71de897d9b 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
       pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
       pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
 
+  dict["te_inspect_ffi"] =
+      pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));
+
   return dict;
 }
 
diff --git a/transformer_engine/jax/debug/__init__.py b/transformer_engine/jax/debug/__init__.py
new file mode 100644
index 0000000000..7fcf194d75
--- /dev/null
+++ b/transformer_engine/jax/debug/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
+
+This API is experimental and may change or be removed without deprecation in future releases.
+"""
+
+__all__ = [
+    "experimental",
+]
diff --git a/transformer_engine/jax/debug/experimental/__init__.py b/transformer_engine/jax/debug/experimental/__init__.py
new file mode 100644
index 0000000000..44a4847660
--- /dev/null
+++ b/transformer_engine/jax/debug/experimental/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
+
+This API is experimental and may change or be removed without deprecation in future releases.
+"""
+
+from .inspect import inspect_array, load_array_dump
+
+__all__ = [
+    "inspect_array",
+    "load_array_dump",
+]
diff --git a/transformer_engine/jax/debug/experimental/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py
new file mode 100644
index 0000000000..9ce46426cf
--- /dev/null
+++ b/transformer_engine/jax/debug/experimental/inspect.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Experimental JAX array inspection utilities."""
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from jax import ffi
+
+from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
+
+__all__ = ["inspect_array", "load_array_dump"]
+
+
+class InspectPrimitive(BasePrimitive):
+    """
+    No-op used to inspect array values.
+ """ + + name = "te_inspect_ffi" + multiple_results = False + impl_static_args = () + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + x_min_aval, + x_max_aval, + x_mean_aval, + x_std_aval, + ): + """ + inspect abstract + """ + assert ( + x_min_aval.shape == () and x_min_aval.dtype == jnp.float32 + ), "x_min must be a scalar with dtype float32" + assert ( + x_max_aval.shape == () and x_max_aval.dtype == jnp.float32 + ), "x_max must be a scalar with dtype float32" + assert ( + x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32 + ), "x_mean must be a scalar with dtype float32" + assert ( + x_std_aval.shape == () and x_std_aval.dtype == jnp.float32 + ), "x_std must be a scalar with dtype float32" + return x_aval + + @staticmethod + def lowering( + ctx, + x, + x_min, + x_max, + x_mean, + x_std, + ): + """ + inspect lowering rules + """ + + return ffi.ffi_lowering( + InspectPrimitive.name, + operand_output_aliases={0: 0}, # donate input buffer to output buffer + )( + ctx, + x, + x_min, + x_max, + x_mean, + x_std, + ) + + @staticmethod + def impl( + x, + x_min, + x_max, + x_mean, + x_std, + ): + """ + inspect implementation + """ + assert InspectPrimitive.inner_primitive is not None + (x) = InspectPrimitive.inner_primitive.bind( + x, + x_min, + x_max, + x_mean, + x_std, + ) + return x + + +register_primitive(InspectPrimitive) + + +def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray: + assert InspectPrimitive.outer_primitive is not None, ( + "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built" + " and registered." + ) + return InspectPrimitive.outer_primitive.bind( + x, + jnp.min(x).astype(jnp.float32), + jnp.max(x).astype(jnp.float32), + jnp.mean(x.astype(jnp.float32)), + jnp.std(x.astype(jnp.float32)), + ) + + +@partial(jax.custom_vjp, nondiff_argnums=()) +def _inspect( + x, +): + """ """ + output, _ = _inspect_fwd_rule( + x, + ) + return output + + +def _inspect_fwd_rule( + x, +): + """""" + ctx = () + x = _inspect_array_inner(x) + return x, ctx + + +def _inspect_bwd_rule( + ctx, + grad, +): + """""" + del ctx + return (grad,) + + +_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) + + +def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: + """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. + + Args: + x (jnp.ndarray): The JAX array to inspect. + name (str): The name of the array for identification in the output. + """ + del name # Name is currently unused, but can be included in the future for more informative output + return _inspect(x) + + +def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: + """Utility function to load a JAX array from a dumped binary file. + + Args: + filename (str): The path to the binary file containing the array data. + shape (tuple): The shape of the array to be loaded. + dtype (jnp.dtype): The data type of the array to be loaded. + + Returns: + jnp.ndarray: The loaded JAX array. + """ + with open(filename, "rb") as f: + data = f.read() + array = jnp.frombuffer(data, dtype=dtype).reshape(shape) + return array