1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1 +1,2 @@
recursive-include transformer_engine/common/include *.*
recursive-include build_tools *.py *.txt
34 changes: 34 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -1921,3 +1921,37 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)


class TestDebugInspectFFI:

@pytest_parametrize_wrapper("shape", [(256, 128)])
@pytest_parametrize_wrapper(
"dtype",
[
jnp.float32,
jnp.bfloat16,
jnp.float16,
# Note: fp4 currently doesn't work
# jnp.float4_e2m1fn
]
+ ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
)
def test_debug_inspect_ffi(self, shape, dtype):
from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

def f(x):
x = x + 1
x = inspect_array(x, "my_array")
x = x + 1
return x

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, jnp.float32)
x = x.astype(dtype)
_ = jax.jit(f)(x)

expected = x + 1
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)

assert_allclose(actual, expected, dtype=dtype)
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);

// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

2 changes: 0 additions & 2 deletions transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,8 +5,6 @@
************************************************************************/
#include <cuda_runtime.h>

#include <iostream>

#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
99 changes: 99 additions & 0 deletions transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -0,0 +1,99 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>

#include <fstream>
#include <iostream>

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
Result_Type output_buf) {
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");

std::vector<uint8_t> input_data(input_buf.size_bytes());
NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));

float min_val{}, max_val{}, mean_val{}, std_val{};
NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));

NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

int device;
NVTE_CHECK_CUDA(cudaGetDevice(&device));

// Write the tensor data to a file as a binary blob
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
std::ofstream file(filename, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();

// Write out a metadata file
std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
std::ofstream meta_file(meta_filename);
NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
meta_file << "{";
meta_file << "\"shape\": [";
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
meta_file << input_buf.dimensions()[i];
if (i < input_buf.dimensions().size() - 1) {
meta_file << ", ";
}
}
meta_file << "], ";
meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
meta_file << ", \"min\": " << min_val;
meta_file << ", \"max\": " << max_val;
meta_file << ", \"mean\": " << mean_val;
meta_file << ", \"std\": " << std_val;
meta_file << "}";
meta_file.close();

// Log the tensor metadata to the console
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
if (i < input_buf.dimensions().size() - 1) {
printf(", ");
}
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // min
.Arg<Buffer_Type>() // max
.Arg<Buffer_Type>() // mean
.Arg<Buffer_Type>() // std
.Ret<Buffer_Type>() // output
);

} // namespace jax
} // namespace transformer_engine
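For reference, a minimal sketch of how the companion metadata file written above could be read back on the host. The filename pattern and JSON fields follow the handler code; the helper name load_inspect_metadata is hypothetical, and the integer "dtype" field is the raw XLA FFI element-type enum, so mapping it back to a JAX dtype is left to the caller.

import json

def load_inspect_metadata(device: int = 0) -> dict:
    # Filename matches the pattern used in InspectFFI: my_tensor_gpu<N>_meta.json
    with open(f"my_tensor_gpu{device}_meta.json", "r", encoding="utf-8") as f:
        meta = json.load(f)
    # meta contains "shape" (list of ints), "dtype" (XLA FFI element-type enum as an int),
    # and the float statistics "min", "max", "mean", "std".
    return meta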
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));

dict["te_inspect_ffi"] =
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

return dict;
}
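As a hedged illustration of how an entry like te_inspect_ffi becomes callable from JAX: a registration dict of this shape is typically looped over and handed to jax.ffi.register_ffi_target for the CUDA platform. The compiled module name transformer_engine_jax and the registrations() binding below are assumptions; the actual wiring lives elsewhere in the JAX extension setup.

import jax
import transformer_engine_jax  # assumed name of the pybind11 module built from this extension

for name, handlers in transformer_engine_jax.registrations().items():
    # Each value is either a single capsule or a dict of stage name -> capsule
    # (e.g. {"initialize": ..., "execute": ...}), both of which jax.ffi accepts.
    jax.ffi.register_ffi_target(name, handlers, platform="CUDA")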

11 changes: 11 additions & 0 deletions transformer_engine/jax/debug/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

__all__ = [
"experimental",
]
14 changes: 14 additions & 0 deletions transformer_engine/jax/debug/experimental/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

from .inspect import inspect_array, load_array_dump

__all__ = [
"inspect_array",
"load_array_dump",
]
174 changes: 174 additions & 0 deletions transformer_engine/jax/debug/experimental/inspect.py
@@ -0,0 +1,174 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive

__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
"""
No-op primitive used to inspect array values.
"""

name = "te_inspect_ffi"
multiple_results = False
impl_static_args = ()
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(
x_aval,
x_min_aval,
x_max_aval,
x_mean_aval,
x_std_aval,
):
"""
inspect abstract
"""
assert (
x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
), "x_min must be a scalar with dtype float32"
assert (
x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
), "x_max must be a scalar with dtype float32"
assert (
x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
), "x_mean must be a scalar with dtype float32"
assert (
x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
), "x_std must be a scalar with dtype float32"
return x_aval

@staticmethod
def lowering(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
inspect lowering rules
"""

return ffi.ffi_lowering(
InspectPrimitive.name,
operand_output_aliases={0: 0}, # donate input buffer to output buffer
)(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
)

@staticmethod
def impl(
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
inspect implementation
"""
assert InspectPrimitive.inner_primitive is not None
x = InspectPrimitive.inner_primitive.bind(
x,
x_min,
x_max,
x_mean,
x_std,
)
return x


register_primitive(InspectPrimitive)


def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
assert InspectPrimitive.outer_primitive is not None, (
"InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
" and registered."
)
return InspectPrimitive.outer_primitive.bind(
x,
jnp.min(x).astype(jnp.float32),
jnp.max(x).astype(jnp.float32),
jnp.mean(x.astype(jnp.float32)),
jnp.std(x.astype(jnp.float32)),
)


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
x,
):
""" """
output, _ = _inspect_fwd_rule(
x,
)
return output


def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = _inspect_array_inner(x)
return x, ctx


def _inspect_bwd_rule(
ctx,
grad,
):
""""""
del ctx
return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output.
"""
del name # Name is currently unused, but can be included in the future for more informative output
return _inspect(x)


def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
"""Utility function to load a JAX array from a dumped binary file.

Args:
filename (str): The path to the binary file containing the array data.
shape (tuple): The shape of the array to be loaded.
dtype (jnp.dtype): The data type of the array to be loaded.

Returns:
jnp.ndarray: The loaded JAX array.
"""
with open(filename, "rb") as f:
data = f.read()
array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
return array
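Putting the pieces together, a minimal end-to-end usage sketch mirroring the unit test above. The dump filename follows the my_tensor_gpu<N>.bin pattern hard-coded in inspect.cpp; the function f, array shape, and dtype are illustrative only.

import jax
import jax.numpy as jnp
from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

def f(x):
    # inspect_array is a pass-through: it dumps the value at this point to
    # my_tensor_gpu<device>.bin, logs its statistics, and returns the array unchanged.
    return inspect_array(x * 2.0, "activations") + 1.0

x = jax.random.normal(jax.random.PRNGKey(0), (256, 128), jnp.bfloat16)
y = jax.jit(f)(x)
y.block_until_ready()  # make sure execution (and the file dump) has completed

# Reload the dumped intermediate from GPU 0 and check it matches the expected value.
dumped = load_array_dump("my_tensor_gpu0.bin", x.shape, x.dtype)
assert jnp.allclose(dumped.astype(jnp.float32), (x * 2.0).astype(jnp.float32))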