Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6d6a8f2
jax integration changes
vthumbe1503 Sep 22, 2025
5ecaaf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
8e7b19c
address review comments:
vthumbe1503 Sep 22, 2025
6aa7b8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
ea5e243
remove unnecessary imports
vthumbe1503 Sep 22, 2025
ef29ea8
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Sep 22, 2025
88a50e0
address review comments
vthumbe1503 Sep 22, 2025
1d30dcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
506e5ef
address review comments:
vthumbe1503 Sep 22, 2025
b2e4588
remove unnecessary imports
vthumbe1503 Sep 22, 2025
c9177b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
d89b878
address review comments
vthumbe1503 Sep 22, 2025
0b9ff1e
minor comments and missed hooking up transformer layer to layernorml mlp
vthumbe1503 Sep 22, 2025
4fa022e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
b02246c
address review comments:
vthumbe1503 Sep 22, 2025
38230d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2025
5ca6cea
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Sep 22, 2025
0e59e7c
accidentally removed passing a parameter
vthumbe1503 Sep 22, 2025
b096ac0
generic jax params struct working
vthumbe1503 Sep 22, 2025
f5f1c4d
remove unnecessary import
vthumbe1503 Sep 23, 2025
c649ae6
Update transformer_engine/pytorch/ops/basic/activation.py
vthumbe1503 Sep 23, 2025
c8aa905
Update transformer_engine/jax/flax/module.py
vthumbe1503 Sep 23, 2025
563aa11
fix the comment
vthumbe1503 Sep 23, 2025
5589e2d
fix the comment
vthumbe1503 Sep 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.activation import activation, ActivationParams
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense

Expand Down Expand Up @@ -170,29 +170,35 @@ def assert_dequantized_grouped_scaled_tensor(
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]

ACTIVATION_TYPES = {
"L0": [
("gelu",),
("gelu", "linear"),
("clamped_silu", "clamped_linear"),
],
"L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type).data
def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type, act_params=act_params).data

def value_n_grad_ref_func(self, x, activation_type):
def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
)
return jitted_reference(x)

def primitive_func(self, inputs, activation_type, quantizer):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out)

@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
Expand All @@ -209,12 +215,16 @@ def test_act_grad(self, shape, activation_type):
x = jnp.repeat(x, len(activation_type), axis=-2)

value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)

prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {}
act_params = (
ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

Expand All @@ -234,17 +244,26 @@ def test_act_grad_with_tensor_scaling_fp8(
self.activation_type = activation_type

value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
)

quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {}

prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_params = (
ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
Expand Down Expand Up @@ -273,10 +292,14 @@ def test_act_forward_with_tensor_scaling_fp8(
q_dtype=output_type,
q_layout=q_layout,
)

te_output = tex.act_lu(x, activation_type, te_quantizer)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {}
act_params = (
ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output)

@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
Expand All @@ -296,10 +319,14 @@ def test_act_forward_with_block_scaling_fp8(
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)

output = tex.act_lu(x, activation_type, quantizer)
ref_out = self.ref_act(x, activation_type)

act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {}
act_params = (
ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU,
SRELU,
SREGLU,
CLAMPED_SWIGLU,
};

/*! \brief Computes the GeLU activation of the input.
Expand Down
60 changes: 51 additions & 9 deletions transformer_engine/jax/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,54 @@

from typing import Sequence, Union, Callable, Optional
from functools import partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp

import numpy as np
from . import cpp_extensions as tex

from .quantize.tensor import NoScaleTensor
from .quantize.quantizer import Quantizer


@dataclass(frozen=True)
class ClampedSwigluParams:
    """Parameters for the Clamped SwiGLU activation function used in GPT OSS.

    Attributes:
        limit: Clamping limit applied inside the activation.
        alpha: Sigmoid scaling factor (SiLU slope), default 1.702.
    """

    limit: float = 7.0
    alpha: float = 1.702

    # NOTE: frozen=True already generates a field-based __hash__ equivalent to
    # hash((self.limit, self.alpha)), so no explicit __hash__ is required.

    def to_ffi_lowering_dict(self) -> dict:
        """Return the parameters as np.float32 scalars for XLA FFI lowering."""
        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}


@dataclass(frozen=True)
class ActivationParams:
    """Container for activation-specific parameters passed to activation kernels."""

    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()
    # Add other activation-specific parameter fields here as needed in the future.

    @staticmethod
    def create(activation_type, **kwargs):
        """Factory method to create ActivationParams based on activation_type.

        Args:
            activation_type: A single activation name or a tuple of names.
            **kwargs: Forwarded to the parameter struct of activations that take
                parameters (currently only Clamped SwiGLU: ``limit``, ``alpha``).

        Returns:
            An ActivationParams instance; default-constructed for activations
            without parameters.
        """
        clamped_names = ("clamped_silu", "clamped_linear")
        if activation_type == clamped_names or activation_type in clamped_names:
            return ActivationParams(ClampedSwigluParams(**kwargs))
        # Default params for activations without parameters.
        return ActivationParams()

    def to_ffi_lowering_dict(self) -> dict:
        """Return a nested dict of FFI-lowering attributes for all sub-params."""
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}


def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
) -> jnp.ndarray:
    """Apply activation functions to input tensor with optional quantization.

    The last dimension of ``x`` must be an exact multiple of the number of
    activation functions, since gated activations consume one slice per
    function.

    Args:
        x: Input tensor to apply activations to
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Activated output tensor

    Raises:
        ValueError: If ``x.shape[-1]`` is not divisible by the number of
            activation functions.
    """
    # A real exception (not an assert) so the check survives `python -O`.
    if x.shape[-1] % len(activation_type) != 0:
        raise ValueError(
            f"x.shape[-1] ({x.shape[-1]}) must be divisible by the number of "
            f"activation functions ({len(activation_type)})"
        )
    return _activation(x, activation_type, quantizer, act_params)


@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer, act_params):
    """Internal implementation of activation with custom VJP.

    Core activation entry point; the activation type and its static
    parameters are non-differentiable arguments of the custom VJP.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Activated tensor
    """
    # Reuse the forward rule; the residuals are only needed by the VJP.
    result, _unused_ctx = _activation_fwd_rule(x, activation_type, quantizer, act_params)
    return result


def _activation_fwd_rule(x, activation_type, quantizer, act_params):
    """Forward pass rule for activation function.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Tuple of (output, context) for backward pass
    """
    out = tex.act_lu(x, activation_type, quantizer, act_params)
    # Dequantization is a no-op for higher-precision tensors.
    out = out.dequantize()
    residuals = (x, quantizer)
    return out, residuals


def _activation_bwd_rule(activation_type, ctx, g):
def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function.

Args:
activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass
g: Gradient from upstream

Expand All @@ -90,7 +132,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)
Expand Down
Loading