From 6d6a8f2e3c5dc1670e909c7132a0f649c595d166 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 16:52:26 +0000 Subject: [PATCH 01/21] jax integration changes --- tests/jax/test_custom_call_compute.py | 48 +++---- .../include/transformer_engine/activation.h | 1 + transformer_engine/jax/activation.py | 19 +-- .../jax/cpp_extensions/activation.py | 130 +++++++++++++----- transformer_engine/jax/csrc/extensions.h | 9 ++ .../jax/csrc/extensions/activation.cpp | 22 ++- .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/flax/module.py | 13 +- transformer_engine/jax/flax/transformer.py | 4 + transformer_engine/jax/layernorm_mlp.py | 12 +- 10 files changed, 183 insertions(+), 76 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9e39b84c0b..776fa8ed55 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -18,7 +18,7 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias +from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias, ClampedSwigluParams from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, _jax_rmsnorm, @@ -170,29 +170,31 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu","clamped_linear"), ] ACTIVATION_TYPES = { "L0": [ ("gelu",), ("gelu", "linear"), + ("clamped_silu", "clamped_linear"), ], "L2": ALL_ACTIVATION_TYPES, } class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation(inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +211,11 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 
+235,7 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,10 +243,10 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) + assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -274,9 +275,9 @@ def test_act_forward_with_tensor_scaling_fp8( q_layout=q_layout, ) - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_params = ClampedSwigluParams.create(limit=1.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +297,9 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_params = ClampedSwigluParams.create(limit=7.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040d..10a6459c3a 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU, }; /*! \brief Computes the GeLU activation of the input. diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43c..c210be228c 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -16,12 +16,13 @@ from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer - +from .cpp_extensions.activation import ClampedSwigluParams def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ClampedSwigluParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. 
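For reference, a minimal usage sketch of the new act_params argument (not part of the patch itself; the import paths, input layout, and the 7.0 / 1.702 defaults follow the tests and the ClampedSwigluParams dataclass introduced in this series):

import jax.numpy as jnp
from transformer_engine.jax.activation import activation
from transformer_engine.jax.cpp_extensions.activation import ClampedSwigluParams

# Gated activations expect the input replicated along the -2 axis, one slice
# per activation in the pair, with the hidden dimension kept last.
x = jnp.ones((16, 128, 2, 512), dtype=jnp.bfloat16)   # (..., act_len=2, hidden)
params = ClampedSwigluParams(limit=7.0, alpha=1.702)  # defaults used in this patch
out = activation(x, ("clamped_silu", "clamped_linear"), act_params=params)
# out.shape == (16, 128, 512); the two slices are combined by the gating product.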
@@ -37,12 +38,12 @@ def activation( Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params=None): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -56,11 +57,11 @@ def _activation(x, activation_type, quantizer): Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: @@ -71,13 +72,13 @@ def _activation_fwd_rule(x, activation_type, quantizer): Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: @@ -90,7 +91,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index cdda201668..c77d2968f4 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -13,9 +13,10 @@ from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - +from dataclasses import dataclass from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -56,17 +57,38 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } - -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + @staticmethod + def create(limit: float = 7.0, alpha: float = 1.702): + return ClampedSwigluParams(limit=limit, alpha=alpha) + def __hash__(self): + return hash(self.limit) + def __eq__(self, value): + if not isinstance(value, ClampedSwigluParams): + return False + return self.limit == value.limit and self.alpha == value.alpha + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + +def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation 
function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + return lambda x: jnp.clip(x, min=-act_params.limit, max=act_params.limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + return lambda x: jax.nn.sigmoid(act_params.alpha * jnp.minimum(x, act_params.limit)) \ + * jnp.minimum(x, act_params.limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -89,7 +111,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -105,11 +128,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -155,7 +179,8 @@ def lowering( is_2x, scale_dtype, is_outer, - ): + act_params, + ): """ te_gated_act_lu_p lowering rules """ @@ -163,9 +188,8 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x, act_params=act_params.to_ffi_lowering_dict() ) return out @@ -180,13 +204,14 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation """ del is_outer assert ActLuPrimitive.inner_primitive is not None - + import numpy as np out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = ( ActLuPrimitive.inner_primitive.bind( x, @@ -198,6 +223,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -226,6 +252,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -236,7 +263,7 @@ def batcher( x, scale = batched_args x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim - + import numpy as np out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim return ( ActLuPrimitive.outer_primitive.bind( @@ -247,6 +274,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -260,6 +288,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -271,6 +300,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -323,6 +353,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -383,6 +414,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -410,11 +442,12 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params prefix = "ActLuPrimitive_" x_rank = len(value_types[0].shape) scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( @@ -458,8 +491,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -477,6 +510,7 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract @@ -533,6 +567,7 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), scaling_mode, is_2x, + act_params, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -578,6 +613,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -596,6 +632,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -611,10 +648,12 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl """ + import numpy as np del is_outer assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( @@ -630,6 +669,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -658,6 +698,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -688,6 +729,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -702,11 +744,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -777,6 +820,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -857,6 +901,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -883,11 +928,12 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params prefix = "BaseDActLuDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( len(value_types[1].shape), 
unique_var=prefix + "x", flatten_axis=-2 @@ -922,20 +968,21 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu(inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ClampedSwigluParams.create() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ClampedSwigluParams.create() x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -950,10 +997,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ClampedSwigluParams] = None ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ClampedSwigluParams.create() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -961,7 +1010,7 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), x.astype(jnp.float32) ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -979,11 +1028,14 @@ def _jax_quantize_dact_dbias( return dx, dbias +from dataclasses import dataclass + def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ClampedSwigluParams] = None ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. 
@@ -1005,24 +1057,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ClampedSwigluParams.create() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1034,6 +1084,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1045,13 +1096,13 @@ def act_lu( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x, + x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, + act_params=act_params ) out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1071,6 +1122,7 @@ def act_lu( is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1093,6 +1145,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ClampedSwigluParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1109,7 +1162,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. 
""" - + act_params = act_params if act_params is not None else ClampedSwigluParams.create() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1122,8 +1175,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1139,6 +1191,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1154,7 +1207,7 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None, act_params=act_params ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1171,6 +1224,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params ) if war_output is not None: return war_output @@ -1178,10 +1232,11 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( - dz=dz, - x=x, + dz=dz.astype(jnp.float32), + x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, + act_params=act_params ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1194,11 +1249,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type, act_params=act_params ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 - ) + dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2) return out, dbias ( @@ -1220,6 +1274,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1248,6 +1303,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ClampedSwigluParams] = None ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1261,11 +1317,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. 
""" + act_params = act_params if act_params is not None else ClampedSwigluParams.create() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f0..4a1a4c3104 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -38,6 +38,11 @@ XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -134,4 +139,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906bb..ffa1a826ef 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ClampedSwigluConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.limit; + auto swiglu_alpha = act_params.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -145,7 +151,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("act_params"), // Can generalize the config later if we have more activations that need params FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -216,7 +223,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, ClampedSwigluConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.limit; + auto swiglu_alpha = act_params.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); 
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -383,6 +393,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -408,7 +421,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c1..08600fd3f4 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -133,6 +133,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54efa..a3fad35806 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -898,6 +898,12 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters for the ClampedSwiglu activation used in GPT OSS. This is only + used when ('clamped_silu', 'clamped_linear') is in :attr:`activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. If more activation + functions require parameters in the future, we might need to change it to a more generic + parameter container. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 @@ -955,7 +961,8 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...]
= ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",) + activations: Sequence[Union[str, Callable]] = ("relu",), + activation_params: dict = None, intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1030,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1039,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9ae..6bfe031bb1 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1631,6 +1631,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. @@ -1751,6 +1754,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801af..07c9a06f20 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -48,6 +48,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. 
@@ -129,6 +130,7 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params quantizer_sets, ) return output @@ -154,6 +156,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -203,6 +206,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params quantizer_sets, ) return output @@ -227,6 +231,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -306,10 +311,13 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) + # At the moment the act_params is only used for ClampedSwiglu + # If there are more activations that require parameters in the future + # we might need to change it to a more generic parameter container casted_act_out = tex.act_lu( dot_1_output, activation_type, - quantizer=ffn2_quantizer_set.x, + act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -371,6 +379,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, ctx, grad, ): @@ -457,6 +466,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From 5ecaaf8a7a183d4ece00610f27fbb8f31839ec5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:55:50 +0000 Subject: [PATCH 02/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 49 +++++++++++---- transformer_engine/jax/activation.py | 1 + .../jax/cpp_extensions/activation.py | 60 ++++++++++++++----- transformer_engine/jax/csrc/extensions.h | 11 ++-- .../jax/csrc/extensions/activation.cpp | 41 +++++++------ transformer_engine/jax/flax/module.py | 8 ++- 6 files changed, 116 insertions(+), 54 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 776fa8ed55..88ef5796ca 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -18,7 +18,11 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias, ClampedSwigluParams +from transformer_engine.jax.cpp_extensions.activation import ( + _jax_act_lu, + _jax_quantize_dact_dbias, + ClampedSwigluParams, +) from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, _jax_rmsnorm, @@ -170,7 +174,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), - ("clamped_silu","clamped_linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -189,12 +193,16 @@ def ref_act(self, x, activation_type, act_params): def value_n_grad_ref_func(self, x, 
activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) def primitive_func(self, inputs, activation_type, quantizer, act_params): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params) + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -213,7 +221,11 @@ def test_act_grad(self, shape, activation_type): value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=0.75, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) @@ -235,7 +247,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3), + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -243,10 +256,16 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) - act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer, act_params) + act_params = ( + ClampedSwigluParams.create(limit=0.75, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) - + assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -275,7 +294,11 @@ def test_act_forward_with_tensor_scaling_fp8( q_layout=q_layout, ) - act_params = ClampedSwigluParams.create(limit=1.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=1.0, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -297,7 +320,11 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - act_params = ClampedSwigluParams.create(limit=7.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=7.0, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) output = tex.act_lu(x, activation_type, quantizer, act_params) ref_out = self.ref_act(x, 
activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index c210be228c..32d482198c 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,6 +18,7 @@ from .quantize.quantizer import Quantizer from .cpp_extensions.activation import ClampedSwigluParams + def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index c77d2968f4..72de7f91c3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -60,22 +60,28 @@ ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } + @dataclass(frozen=True) class ClampedSwigluParams: limit: float = 7.0 alpha: float = 1.702 + @staticmethod def create(limit: float = 7.0, alpha: float = 1.702): return ClampedSwigluParams(limit=limit, alpha=alpha) + def __hash__(self): return hash(self.limit) + def __eq__(self, value): if not isinstance(value, ClampedSwigluParams): return False return self.limit == value.limit and self.alpha == value.alpha + def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": @@ -87,8 +93,9 @@ def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSw if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) if fn_or_string == "clamped_silu": - return lambda x: jax.nn.sigmoid(act_params.alpha * jnp.minimum(x, act_params.limit)) \ - * jnp.minimum(x, act_params.limit) + return lambda x: jax.nn.sigmoid( + act_params.alpha * jnp.minimum(x, act_params.limit) + ) * jnp.minimum(x, act_params.limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -180,7 +187,7 @@ def lowering( scale_dtype, is_outer, act_params, - ): + ): """ te_gated_act_lu_p lowering rules """ @@ -189,7 +196,13 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x, act_params=act_params.to_ffi_lowering_dict() + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -212,6 +225,7 @@ def impl( del is_outer assert ActLuPrimitive.inner_primitive is not None import numpy as np + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = ( ActLuPrimitive.inner_primitive.bind( x, @@ -264,6 +278,7 @@ def batcher( x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim import numpy as np + out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim return ( ActLuPrimitive.outer_primitive.bind( @@ -654,6 +669,7 @@ def impl( te_dact_dbias_quantize_p impl """ import numpy as np + del is_outer assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( @@ -968,7 +984,9 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of 
BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -997,7 +1015,7 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization @@ -1010,7 +1028,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -1028,6 +1047,7 @@ def _jax_quantize_dact_dbias( return dx, dbias + from dataclasses import dataclass @@ -1035,7 +1055,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. 
@@ -1099,7 +1119,7 @@ def act_lu( x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, - act_params=act_params + act_params=act_params, ) out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out @@ -1207,7 +1227,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None, act_params=act_params + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1224,7 +1248,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, - act_params=act_params + act_params=act_params, ) if war_output is not None: return war_output @@ -1236,7 +1260,7 @@ def quantize_dact_dbias( x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, - act_params=act_params + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1249,10 +1273,14 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type, act_params=act_params + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2) + dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + ) return out, dbias ( @@ -1303,7 +1331,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. 
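A sanity-check sketch for the new backward path, mirroring what _jax_quantize_dact_dbias in this file already does with jax.vjp (shapes, dtypes, and tolerances are illustrative, not taken from the patch):

import jax
import jax.numpy as jnp

def check_clamped_swiglu_bwd(dz, x, act_params):
    # dz: (..., hidden) cotangent; x: (..., 2, hidden) forward input
    act_type = ("clamped_silu", "clamped_linear")
    ref_fwd = lambda x_: _jax_act_lu(x_, act_type, act_params=act_params).data
    _, vjp_fn = jax.vjp(ref_fwd, x.astype(jnp.float32))
    (ref_dx,) = vjp_fn(dz.astype(jnp.float32))
    te_dx = dact_lu(dz, x, act_type, act_params=act_params).data
    return jnp.allclose(te_dx, ref_dx.astype(te_dx.dtype), rtol=1e-2, atol=1e-2)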
@@ -1324,6 +1352,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - act_params=act_params + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 4a1a4c3104..c3f730020d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -39,8 +39,8 @@ namespace transformer_engine { namespace jax { struct ClampedSwigluConfig { - float limit; - float alpha; + float limit; + float alpha; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -139,8 +139,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::ClampedSwigluConfig, - ::xla::ffi::StructMember("limit"), - ::xla::ffi::StructMember("alpha")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index ffa1a826ef..31f346b2b5 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -129,7 +129,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; case NVTE_Activation_Type::CLAMPED_SWIGLU: - nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -139,21 +140,23 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("act_enum") - .Attr("scaling_mode") - .Attr("is_2x") - .Attr("act_params"), // Can generalize the config later if we have more activations that need params - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x") + .Attr( + "act_params"), // Can generalize the config later if we have more activations that need params + FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -223,7 +226,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, ClampedSwigluConfig act_params) { + int64_t act_enum, bool is_2x, bool is_dbias, + ClampedSwigluConfig act_params) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = 
act_params.limit; auto swiglu_alpha = act_params.alpha; @@ -394,7 +398,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; case NVTE_Activation_Type::CLAMPED_SWIGLU: - nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a3fad35806..8d73c4bf1a 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -961,8 +961,8 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",), - activation_params: dict = None, + activations: Sequence[Union[str, Callable]] = (("relu",),) + activation_params: dict = (None,) intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1039,7 +1039,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) From 8e7b19c5d0fcada964255ae8af0f2d2b0cfff428 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:16:21 +0000 Subject: [PATCH 03/21] address review comments: Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 13 +++---- transformer_engine/jax/activation.py | 18 ++++++++- .../jax/cpp_extensions/activation.py | 39 ++++--------------- transformer_engine/jax/flax/module.py | 4 +- transformer_engine/jax/layernorm_mlp.py | 9 +++-- 5 files changed, 37 insertions(+), 46 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 88ef5796ca..2726c08134 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -20,8 +20,7 @@ from transformer_engine.jax.cpp_extensions.activation import ( _jax_act_lu, - _jax_quantize_dact_dbias, - ClampedSwigluParams, + _jax_quantize_dact_dbias ) from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, @@ -46,7 +45,7 @@ noop_quantizer_set, ) from transformer_engine.jax.quantize import helper -from transformer_engine.jax.activation import activation +from transformer_engine.jax.activation import activation, ClampedSwigluParams from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense @@ -222,7 +221,7 @@ def test_act_grad(self, shape, activation_type): value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) act_params = ( - ClampedSwigluParams.create(limit=0.75, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -257,7 +256,7 @@ def 
test_act_grad_with_tensor_scaling_fp8( q_layout=QuantizeLayout.ROWWISE, ) act_params = ( - ClampedSwigluParams.create(limit=0.75, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -295,7 +294,7 @@ def test_act_forward_with_tensor_scaling_fp8( ) act_params = ( - ClampedSwigluParams.create(limit=1.0, alpha=1.702) + ClampedSwigluParams(limit=1.0, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -321,7 +320,7 @@ def test_act_forward_with_block_scaling_fp8( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) act_params = ( - ClampedSwigluParams.create(limit=7.0, alpha=1.702) + ClampedSwigluParams(limit=7.0, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 32d482198c..e8a749a13e 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -8,15 +8,29 @@ from typing import Sequence, Union, Callable, Optional from functools import partial +from dataclasses import dataclass import jax import jax.numpy as jnp - +import numpy as np from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer -from .cpp_extensions.activation import ClampedSwigluParams + +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + def __hash__(self): + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + def activation( diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 72de7f91c3..f6d74b0a9b 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -37,6 +37,7 @@ DelayedScaleQuantizer, ScalingMode, ) +from ..activation import ClampedSwigluParams if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -61,26 +62,6 @@ } -@dataclass(frozen=True) -class ClampedSwigluParams: - limit: float = 7.0 - alpha: float = 1.702 - - @staticmethod - def create(limit: float = 7.0, alpha: float = 1.702): - return ClampedSwigluParams(limit=limit, alpha=alpha) - - def __hash__(self): - return hash(self.limit) - - def __eq__(self, value): - if not isinstance(value, ClampedSwigluParams): - return False - return self.limit == value.limit and self.alpha == value.alpha - - def to_ffi_lowering_dict(self): - return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" @@ -224,7 +205,6 @@ def impl( """ del is_outer assert ActLuPrimitive.inner_primitive is not None - import numpy as np out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = ( ActLuPrimitive.inner_primitive.bind( @@ -277,7 +257,6 @@ def batcher( x, scale = batched_args x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim - import numpy as np out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim return ( @@ -668,8 +647,6 @@ def impl( """ te_dact_dbias_quantize_p impl """ - import numpy as np - del is_outer assert 
BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( @@ -990,13 +967,13 @@ def _jax_act_lu( """ JAX native activation implementation """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): @@ -1020,7 +997,7 @@ def _jax_quantize_dact_dbias( """ JAX implementation of dact_lu and dbias with optional quantization """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1077,7 +1054,7 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() if not ActLuPrimitive.enabled(): return _jax_act_lu(x, activation_type, quantizer, act_params) @@ -1116,7 +1093,7 @@ def act_lu( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, act_params=act_params, @@ -1182,7 +1159,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1345,7 +1322,7 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() output, _ = quantize_dact_dbias( dz=dz, x=x, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8d73c4bf1a..9234a1d083 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -961,8 +961,8 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] 
= ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = (("relu",),) - activation_params: dict = (None,) + activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 07c9a06f20..a815710b34 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .activation import ClampedSwigluParams from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -130,7 +131,7 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - activation_params + activation_params, quantizer_sets, ) return output @@ -206,7 +207,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - activation_params + activation_params, quantizer_sets, ) return output @@ -317,7 +318,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out = tex.act_lu( dot_1_output, activation_type, - act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None + act_params=ClampedSwigluParams(**activation_params) if activation_params else None ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -466,7 +467,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None + act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From 6aa7b8dc9d1b819b1916f4a59371782f42351776 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:17:46 +0000 Subject: [PATCH 04/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 5 +---- transformer_engine/jax/activation.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 1 - transformer_engine/jax/layernorm_mlp.py | 4 ++-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2726c08134..3d1d7f1507 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -18,10 +18,7 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.cpp_extensions.activation import ( - _jax_act_lu, - _jax_quantize_dact_dbias -) +from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, _jax_rmsnorm, diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index e8a749a13e..196bf665b6 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,6 +18,7 @@ from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer + @dataclass(frozen=True) class 
ClampedSwigluParams: limit: float = 7.0 @@ -32,7 +33,6 @@ def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f6d74b0a9b..3447017859 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -62,7 +62,6 @@ } - def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index a815710b34..1f4882cb09 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -318,7 +318,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out = tex.act_lu( dot_1_output, activation_type, - act_params=ClampedSwigluParams(**activation_params) if activation_params else None + act_params=ClampedSwigluParams(**activation_params) if activation_params else None, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -467,7 +467,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None + act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From ea5e24301235a4760ba6eaad715e6314df29f824 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:20:19 +0000 Subject: [PATCH 05/21] remove unnecessary imports Signed-off-by: Varun Thumbe --- transformer_engine/jax/cpp_extensions/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f6d74b0a9b..7952dd0c1c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -16,7 +16,6 @@ import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type -from dataclasses import dataclass from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, From 88a50e00fd229c82db39615f7c1d87997fcc2d01 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:22:10 +0000 Subject: [PATCH 06/21] address review comments Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3d1d7f1507..0f2988dc5b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -291,7 +291,7 @@ def test_act_forward_with_tensor_scaling_fp8( ) act_params = ( - ClampedSwigluParams(limit=1.0, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) From 1d30dcde562088ec013ff508ae8a1b793f3ce64b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:55:50 +0000 Subject: [PATCH 07/21] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 49 +++++++++++---- transformer_engine/jax/activation.py | 1 + .../jax/cpp_extensions/activation.py | 60 ++++++++++++++----- transformer_engine/jax/csrc/extensions.h | 11 ++-- .../jax/csrc/extensions/activation.cpp | 41 +++++++------ transformer_engine/jax/flax/module.py | 8 ++- 6 files changed, 116 insertions(+), 54 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 776fa8ed55..88ef5796ca 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -18,7 +18,11 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias, ClampedSwigluParams +from transformer_engine.jax.cpp_extensions.activation import ( + _jax_act_lu, + _jax_quantize_dact_dbias, + ClampedSwigluParams, +) from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, _jax_rmsnorm, @@ -170,7 +174,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), - ("clamped_silu","clamped_linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -189,12 +193,16 @@ def ref_act(self, x, activation_type, act_params): def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) def primitive_func(self, inputs, activation_type, quantizer, act_params): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params) + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -213,7 +221,11 @@ def test_act_grad(self, shape, activation_type): value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=0.75, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) @@ -235,7 +247,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3), + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -243,10 +256,16 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) - act_params = ClampedSwigluParams.create(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer, act_params) + act_params = ( + 
ClampedSwigluParams.create(limit=0.75, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) - + assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -275,7 +294,11 @@ def test_act_forward_with_tensor_scaling_fp8( q_layout=q_layout, ) - act_params = ClampedSwigluParams.create(limit=1.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=1.0, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -297,7 +320,11 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - act_params = ClampedSwigluParams.create(limit=7.0, alpha=1.702) if activation_type == ("clamped_silu","clamped_linear") else None + act_params = ( + ClampedSwigluParams.create(limit=7.0, alpha=1.702) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) output = tex.act_lu(x, activation_type, quantizer, act_params) ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index c210be228c..32d482198c 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,6 +18,7 @@ from .quantize.quantizer import Quantizer from .cpp_extensions.activation import ClampedSwigluParams + def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index c77d2968f4..72de7f91c3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -60,22 +60,28 @@ ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } + @dataclass(frozen=True) class ClampedSwigluParams: limit: float = 7.0 alpha: float = 1.702 + @staticmethod def create(limit: float = 7.0, alpha: float = 1.702): return ClampedSwigluParams(limit=limit, alpha=alpha) + def __hash__(self): return hash(self.limit) + def __eq__(self, value): if not isinstance(value, ClampedSwigluParams): return False return self.limit == value.limit and self.alpha == value.alpha + def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": @@ -87,8 +93,9 @@ def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSw if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) if fn_or_string == "clamped_silu": - return lambda x: jax.nn.sigmoid(act_params.alpha * jnp.minimum(x, act_params.limit)) \ - * jnp.minimum(x, act_params.limit) + return lambda x: jax.nn.sigmoid( + act_params.alpha * jnp.minimum(x, act_params.limit) 
+ ) * jnp.minimum(x, act_params.limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -180,7 +187,7 @@ def lowering( scale_dtype, is_outer, act_params, - ): + ): """ te_gated_act_lu_p lowering rules """ @@ -189,7 +196,13 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x, act_params=act_params.to_ffi_lowering_dict() + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -212,6 +225,7 @@ def impl( del is_outer assert ActLuPrimitive.inner_primitive is not None import numpy as np + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = ( ActLuPrimitive.inner_primitive.bind( x, @@ -264,6 +278,7 @@ def batcher( x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim import numpy as np + out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim return ( ActLuPrimitive.outer_primitive.bind( @@ -654,6 +669,7 @@ def impl( te_dact_dbias_quantize_p impl """ import numpy as np + del is_outer assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( @@ -968,7 +984,9 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -997,7 +1015,7 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization @@ -1010,7 +1028,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -1028,6 +1047,7 @@ def _jax_quantize_dact_dbias( return dx, dbias + from dataclasses import dataclass @@ -1035,7 +1055,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. 
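For reference, a minimal JAX sketch of what the ("clamped_silu", "clamped_linear") pair above computes on a gated input; the helper name and the example shape are illustrative only, while the limit/alpha defaults follow ClampedSwigluParams.

import jax
import jax.numpy as jnp

def clamped_swiglu_reference(x, limit=7.0, alpha=1.702):
    # x: (..., 2, hidden) with the gate half first, matching the act_len == 2
    # replication that _jax_act_lu asserts on the -2 axis.
    gate, linear = jnp.split(x, 2, axis=-2)
    gate = jnp.minimum(gate, limit)              # "clamped_silu" clamps from above only
    silu = jax.nn.sigmoid(alpha * gate) * gate
    lin = jnp.clip(linear, -limit, limit) + 1.0  # "clamped_linear" clamps both sides, then shifts
    return silu * lin                            # gated product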
@@ -1099,7 +1119,7 @@ def act_lu( x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, - act_params=act_params + act_params=act_params, ) out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out @@ -1207,7 +1227,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None, act_params=act_params + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1224,7 +1248,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, - act_params=act_params + act_params=act_params, ) if war_output is not None: return war_output @@ -1236,7 +1260,7 @@ def quantize_dact_dbias( x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, - act_params=act_params + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1249,10 +1273,14 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type, act_params=act_params + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2) + dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + ) return out, dbias ( @@ -1303,7 +1331,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None + act_params: Optional[ClampedSwigluParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. 
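The JAX-native fallback above recovers dact (and optionally dbias) by differentiating the forward activation with jax.vjp rather than calling the fused kernel. A rough, self-contained sketch of that pattern with assumed names and an assumed dbias reduction:

import jax
import jax.numpy as jnp

def dact_dbias_fallback(dz, x, act_fn):
    # Pull the incoming gradient back through the activation in fp32,
    # then reduce dx over all leading (batch) axes to form dbias.
    _, vjp_fn = jax.vjp(act_fn, x.astype(jnp.float32))
    (dx,) = vjp_fn(dz.astype(jnp.float32))
    dbias = dx.sum(axis=tuple(range(dx.ndim - 2)))
    return dx.astype(x.dtype), dbias.astype(x.dtype)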
@@ -1324,6 +1352,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - act_params=act_params + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 4a1a4c3104..c3f730020d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -39,8 +39,8 @@ namespace transformer_engine { namespace jax { struct ClampedSwigluConfig { - float limit; - float alpha; + float limit; + float alpha; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -139,8 +139,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::ClampedSwigluConfig, - ::xla::ffi::StructMember("limit"), - ::xla::ffi::StructMember("alpha")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index ffa1a826ef..31f346b2b5 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -129,7 +129,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; case NVTE_Activation_Type::CLAMPED_SWIGLU: - nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -139,21 +140,23 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("act_enum") - .Attr("scaling_mode") - .Attr("is_2x") - .Attr("act_params"), // Can generalize the config later if we have more activations that need params - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x") + .Attr( + "act_params"), // Can generalize the config later if we have more activations that need params + FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -223,7 +226,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, ClampedSwigluConfig act_params) { + int64_t act_enum, bool is_2x, bool is_dbias, + ClampedSwigluConfig act_params) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = 
act_params.limit; auto swiglu_alpha = act_params.alpha; @@ -394,7 +398,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; case NVTE_Activation_Type::CLAMPED_SWIGLU: - nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, stream); + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a3fad35806..8d73c4bf1a 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -961,8 +961,8 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",), - activation_params: dict = None, + activations: Sequence[Union[str, Callable]] = (("relu",),) + activation_params: dict = (None,) intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1039,7 +1039,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) From 506e5efbbaba7eb7226542d518550b9a794a1526 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:16:21 +0000 Subject: [PATCH 08/21] address review comments: Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 13 +++---- transformer_engine/jax/activation.py | 18 ++++++++- .../jax/cpp_extensions/activation.py | 39 ++++--------------- transformer_engine/jax/flax/module.py | 4 +- transformer_engine/jax/layernorm_mlp.py | 9 +++-- 5 files changed, 37 insertions(+), 46 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 88ef5796ca..2726c08134 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -20,8 +20,7 @@ from transformer_engine.jax.cpp_extensions.activation import ( _jax_act_lu, - _jax_quantize_dact_dbias, - ClampedSwigluParams, + _jax_quantize_dact_dbias ) from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, @@ -46,7 +45,7 @@ noop_quantizer_set, ) from transformer_engine.jax.quantize import helper -from transformer_engine.jax.activation import activation +from transformer_engine.jax.activation import activation, ClampedSwigluParams from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense @@ -222,7 +221,7 @@ def test_act_grad(self, shape, activation_type): value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) act_params = ( - ClampedSwigluParams.create(limit=0.75, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -257,7 +256,7 @@ def 
test_act_grad_with_tensor_scaling_fp8( q_layout=QuantizeLayout.ROWWISE, ) act_params = ( - ClampedSwigluParams.create(limit=0.75, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -295,7 +294,7 @@ def test_act_forward_with_tensor_scaling_fp8( ) act_params = ( - ClampedSwigluParams.create(limit=1.0, alpha=1.702) + ClampedSwigluParams(limit=1.0, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -321,7 +320,7 @@ def test_act_forward_with_block_scaling_fp8( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) act_params = ( - ClampedSwigluParams.create(limit=7.0, alpha=1.702) + ClampedSwigluParams(limit=7.0, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 32d482198c..e8a749a13e 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -8,15 +8,29 @@ from typing import Sequence, Union, Callable, Optional from functools import partial +from dataclasses import dataclass import jax import jax.numpy as jnp - +import numpy as np from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer -from .cpp_extensions.activation import ClampedSwigluParams + +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + def __hash__(self): + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + def activation( diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 72de7f91c3..f6d74b0a9b 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -37,6 +37,7 @@ DelayedScaleQuantizer, ScalingMode, ) +from ..activation import ClampedSwigluParams if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -61,26 +62,6 @@ } -@dataclass(frozen=True) -class ClampedSwigluParams: - limit: float = 7.0 - alpha: float = 1.702 - - @staticmethod - def create(limit: float = 7.0, alpha: float = 1.702): - return ClampedSwigluParams(limit=limit, alpha=alpha) - - def __hash__(self): - return hash(self.limit) - - def __eq__(self, value): - if not isinstance(value, ClampedSwigluParams): - return False - return self.limit == value.limit and self.alpha == value.alpha - - def to_ffi_lowering_dict(self): - return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" @@ -224,7 +205,6 @@ def impl( """ del is_outer assert ActLuPrimitive.inner_primitive is not None - import numpy as np out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = ( ActLuPrimitive.inner_primitive.bind( @@ -277,7 +257,6 @@ def batcher( x, scale = batched_args x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim - import numpy as np out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim return ( @@ -668,8 +647,6 @@ def impl( """ te_dact_dbias_quantize_p impl """ - import numpy as np - del is_outer assert 
BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( @@ -990,13 +967,13 @@ def _jax_act_lu( """ JAX native activation implementation """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): @@ -1020,7 +997,7 @@ def _jax_quantize_dact_dbias( """ JAX implementation of dact_lu and dbias with optional quantization """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1077,7 +1054,7 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() if not ActLuPrimitive.enabled(): return _jax_act_lu(x, activation_type, quantizer, act_params) @@ -1116,7 +1093,7 @@ def act_lu( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, act_params=act_params, @@ -1182,7 +1159,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1345,7 +1322,7 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ - act_params = act_params if act_params is not None else ClampedSwigluParams.create() + act_params = act_params if act_params is not None else ClampedSwigluParams() output, _ = quantize_dact_dbias( dz=dz, x=x, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8d73c4bf1a..9234a1d083 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -961,8 +961,8 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] 
= ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = (("relu",),) - activation_params: dict = (None,) + activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 07c9a06f20..a815710b34 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .activation import ClampedSwigluParams from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -130,7 +131,7 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - activation_params + activation_params, quantizer_sets, ) return output @@ -206,7 +207,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - activation_params + activation_params, quantizer_sets, ) return output @@ -317,7 +318,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out = tex.act_lu( dot_1_output, activation_type, - act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None + act_params=ClampedSwigluParams(**activation_params) if activation_params else None ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -466,7 +467,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - act_params=tex.ClampedSwigluParams.create(**activation_params) if activation_params else None + act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From b2e458885d32782ac43bde548d9aa191ae9e7692 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:20:19 +0000 Subject: [PATCH 09/21] remove unnecessary imports Signed-off-by: Varun Thumbe --- transformer_engine/jax/cpp_extensions/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f6d74b0a9b..7952dd0c1c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -16,7 +16,6 @@ import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type -from dataclasses import dataclass from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, From c9177b49aed38974c158f383efa11aed8e7979ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:17:46 +0000 Subject: [PATCH 10/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 5 +---- transformer_engine/jax/activation.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 1 - transformer_engine/jax/layernorm_mlp.py | 4 ++-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2726c08134..3d1d7f1507 100644 --- 
a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -18,10 +18,7 @@ from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.cpp_extensions.activation import ( - _jax_act_lu, - _jax_quantize_dact_dbias -) +from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias from transformer_engine.jax.cpp_extensions.normalization import ( _jax_layernorm, _jax_rmsnorm, diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index e8a749a13e..196bf665b6 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,6 +18,7 @@ from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer + @dataclass(frozen=True) class ClampedSwigluParams: limit: float = 7.0 @@ -32,7 +33,6 @@ def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 7952dd0c1c..b4cb4fc9c5 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -61,7 +61,6 @@ } - def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index a815710b34..1f4882cb09 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -318,7 +318,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out = tex.act_lu( dot_1_output, activation_type, - act_params=ClampedSwigluParams(**activation_params) if activation_params else None + act_params=ClampedSwigluParams(**activation_params) if activation_params else None, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -467,7 +467,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None + act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From d89b878631fb3dda3ebfaf112c0c1fa09be8d0fe Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:22:10 +0000 Subject: [PATCH 11/21] address review comments Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3d1d7f1507..0f2988dc5b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -291,7 +291,7 @@ def test_act_forward_with_tensor_scaling_fp8( ) act_params = ( - ClampedSwigluParams(limit=1.0, alpha=1.702) + ClampedSwigluParams(limit=0.75, alpha=1.702) if activation_type == ("clamped_silu", "clamped_linear") else None ) From 0b9ff1e6cb620fc1bd64cef74dd7821f02107628 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:46:41 +0000 Subject: [PATCH 12/21] minor comments and missed hooking up 
transformer layer to layernorm mlp Signed-off-by: Varun Thumbe --- transformer_engine/jax/activation.py | 10 +++++++++- transformer_engine/jax/flax/transformer.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 196bf665b6..faad120a39 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -48,6 +48,8 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor @@ -58,7 +60,7 @@ def activation( @partial(jax.custom_vjp, nondiff_argnums=(1, 3)) -def _activation(x, activation_type, quantizer, act_params=None): +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -68,6 +70,8 @@ def _activation(x, activation_type, quantizer, act_params): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor @@ -83,6 +87,8 @@ def _activation_fwd_rule(x, activation_type, quantizer, act_params): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Tuple of (output, context) for backward pass @@ -98,6 +104,8 @@ def _activation_bwd_rule(activation_type, act_params, ctx, g): Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
ctx: Context from forward pass g: Gradient from upstream diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 6bfe031bb1..fc72b9bc3f 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -2049,6 +2049,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, From 4fa022e8e8bf4de3a164b5172441b569ff0761fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:55:50 +0000 Subject: [PATCH 13/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/activation.py | 1 + transformer_engine/jax/cpp_extensions/activation.py | 1 + 2 files changed, 2 insertions(+) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index faad120a39..40df1099fe 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -33,6 +33,7 @@ def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index b4cb4fc9c5..7952dd0c1c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -61,6 +61,7 @@ } + def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": From b02246ced1c3d627abc2830610a30828ba016c8b Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:16:21 +0000 Subject: [PATCH 14/21] address review comments: Signed-off-by: Varun Thumbe --- transformer_engine/jax/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 40df1099fe..b8e62152f3 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,7 +18,6 @@ from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer - @dataclass(frozen=True) class ClampedSwigluParams: limit: float = 7.0 From 38230d19b459d16930ac67395747df834099a670 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:17:46 +0000 Subject: [PATCH 15/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe --- transformer_engine/jax/activation.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index b8e62152f3..faad120a39 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -18,6 +18,7 @@ from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer + 
@dataclass(frozen=True) class ClampedSwigluParams: limit: float = 7.0 @@ -32,7 +33,6 @@ def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} - def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 7952dd0c1c..b4cb4fc9c5 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -61,7 +61,6 @@ } - def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): """Convert a string to an activation function.""" if fn_or_string == "linear": From 0e59e7cd7a7f0c7ac475e64e5e7c7a0c592e71e8 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 21:54:38 +0000 Subject: [PATCH 16/21] accidentally removed passing a parameter Signed-off-by: Varun Thumbe --- transformer_engine/jax/layernorm_mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 1f4882cb09..cf7aff15b2 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -318,6 +318,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out = tex.act_lu( dot_1_output, activation_type, + quantizer=ffn2_quantizer_set.x, act_params=ClampedSwigluParams(**activation_params) if activation_params else None, ) From b096ac0af296bd001c414f076412bad79415352c Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 22 Sep 2025 23:55:58 +0000 Subject: [PATCH 17/21] generic jax params struct working Signed-off-by: Varun Thumbe --- tests/jax/test_custom_call_compute.py | 16 +++++---- transformer_engine/jax/activation.py | 20 ++++++++++- .../jax/cpp_extensions/activation.py | 34 ++++++++++--------- transformer_engine/jax/csrc/extensions.h | 8 +++++ .../jax/csrc/extensions/activation.cpp | 16 ++++----- transformer_engine/jax/layernorm_mlp.py | 7 ++-- 6 files changed, 67 insertions(+), 34 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 0f2988dc5b..116ac7aea7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -42,7 +42,7 @@ noop_quantizer_set, ) from transformer_engine.jax.quantize import helper -from transformer_engine.jax.activation import activation, ClampedSwigluParams +from transformer_engine.jax.activation import activation, ActivationParams from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense @@ -217,8 +217,9 @@ def test_act_grad(self, shape, activation_type): value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} act_params = ( - ClampedSwigluParams(limit=0.75, alpha=1.702) + ActivationParams.create(activation_type=activation_type, **act_args) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -252,8 +253,10 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} + act_params = ( - ClampedSwigluParams(limit=0.75, alpha=1.702) + ActivationParams.create(activation_type=activation_type, **act_args) if activation_type == 
("clamped_silu", "clamped_linear") else None ) @@ -289,9 +292,9 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} act_params = ( - ClampedSwigluParams(limit=0.75, alpha=1.702) + ActivationParams.create(activation_type=activation_type, **act_args) if activation_type == ("clamped_silu", "clamped_linear") else None ) @@ -316,8 +319,9 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} act_params = ( - ClampedSwigluParams(limit=7.0, alpha=1.702) + ActivationParams.create(activation_type=activation_type, **act_args) if activation_type == ("clamped_silu", "clamped_linear") else None ) diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index faad120a39..f9e6537f18 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -32,12 +32,30 @@ def __hash__(self): def to_ffi_lowering_dict(self): return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} +@dataclass(frozen=True) +class ActivationParams: + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + # Add other activation-specific parameter fields here as needed in the future + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + if activation_type == ("clamped_silu", "clamped_linear") or activation_type == "clamped_silu" or activation_type == "clamped_linear": + return ActivationParams(ClampedSwigluParams(**kwargs)) + else: + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None, + act_params: Optional[ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. 
diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index b4cb4fc9c5..fd41e1d330 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -36,7 +36,7 @@ DelayedScaleQuantizer, ScalingMode, ) -from ..activation import ClampedSwigluParams +from ..activation import ActivationParams if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -61,20 +61,23 @@ } -def _convert_to_activation_function(fn_or_string, act_params: Optional[ClampedSwigluParams] = None): +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x if fn_or_string == "clamped_linear": - return lambda x: jnp.clip(x, min=-act_params.limit, max=act_params.limit) + 1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha return lambda x: jax.nn.sigmoid( - act_params.alpha * jnp.minimum(x, act_params.limit) - ) * jnp.minimum(x, act_params.limit) + alpha * jnp.minimum(x, limit) + ) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -960,18 +963,17 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): def _jax_act_lu( - inputs, activation_type, quantizer=None, act_params: Optional[ClampedSwigluParams] = None + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None ) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ - act_params = act_params if act_params is not None else ClampedSwigluParams() + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams() x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): @@ -990,12 +992,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ - act_params = act_params if act_params is not None else ClampedSwigluParams() + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1030,7 +1032,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. 
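For context, the lowering hands act_params to the FFI call as a nested attribute dict; a small sketch of the expected layout, which the C++ side decodes into the ActivationConfig/ClampedSwigluConfig structs registered in extensions.h further below (the printed layout is an assumption, not captured output).

from transformer_engine.jax.activation import ActivationParams

params = ActivationParams.create(("clamped_silu", "clamped_linear"), limit=7.0, alpha=1.702)
attrs = params.to_ffi_lowering_dict()
# Expected: {"clamped_swiglu": {"limit": np.float32(7.0), "alpha": np.float32(1.702)}}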
@@ -1052,7 +1054,7 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - act_params = act_params if act_params is not None else ClampedSwigluParams() + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): return _jax_act_lu(x, activation_type, quantizer, act_params) @@ -1140,7 +1142,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1157,7 +1159,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - act_params = act_params if act_params is not None else ClampedSwigluParams() + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1306,7 +1308,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - act_params: Optional[ClampedSwigluParams] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1320,7 +1322,7 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ - act_params = act_params if act_params is not None else ClampedSwigluParams() + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c3f730020d..0982e0ac88 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -43,6 +43,11 @@ struct ClampedSwigluConfig { float alpha; }; +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu_config; +}; + + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -142,4 +147,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, ::xla::ffi::StructMember<float>("limit"), ::xla::ffi::StructMember<float>("alpha")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>( + "clamped_swiglu")); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
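The nested attribute decoding above is the C++ counterpart of the nested dictionary built by `ActivationParams.to_ffi_lowering_dict()` on the Python side. Roughly, the lowering attribute is expected to take the following shape (a sketch; the values are illustrative):

```python
from transformer_engine.jax.activation import ActivationParams, ClampedSwigluParams

params = ActivationParams(ClampedSwigluParams(limit=0.75, alpha=1.702))
attrs = params.to_ffi_lowering_dict()
# attrs == {"clamped_swiglu": {"limit": np.float32(0.75), "alpha": np.float32(1.702)}}
# The outer "clamped_swiglu" key is decoded into ActivationConfig.clamped_swiglu_config,
# and the inner "limit" / "alpha" keys into ClampedSwigluConfig.limit / .alpha.
```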
diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 31f346b2b5..7e7e3178b4 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,10 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int, ClampedSwigluConfig act_params) { + bool is_2x_int, ActivationConfig act_params) { // parameters for clamped swiglu used in GPT OSS - auto swiglu_limit = act_params.limit; - auto swiglu_alpha = act_params.alpha; + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -154,7 +154,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr<int64_t>("act_enum") .Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<bool>("is_2x") - .Attr<ClampedSwigluConfig>( + .Attr<ActivationConfig>( "act_params"), // Can generalize the config later if we have more activations that need params FFI_CudaGraph_Traits); @@ -227,10 +227,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, bool is_dbias, - ClampedSwigluConfig act_params) { + ActivationConfig act_params) { // parameters for clamped swiglu used in GPT OSS - auto swiglu_limit = act_params.limit; - auto swiglu_alpha = act_params.alpha; + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -427,7 +427,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr<int64_t>("act_enum") .Attr<bool>("is_2x") .Attr<bool>("is_dbias") - .Attr<ClampedSwigluConfig>("act_params"), + .Attr<ActivationConfig>("act_params"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index cf7aff15b2..3878d5cc51 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,7 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from .
import cpp_extensions as tex -from .activation import ClampedSwigluParams +from .activation import ActivationParams from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -319,7 +319,8 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, - act_params=ClampedSwigluParams(**activation_params) if activation_params else None, + act_params=ActivationParams.create(activation_type, **activation_params) + if activation_params else None, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -468,7 +469,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - act_params=tex.ClampedSwigluParams(**activation_params) if activation_params else None, + act_params=ActivationParams.create(activation_type, **activation_params) if activation_params else None, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From f5f1c4ded81b21da1d8290407f0ad242eea6cd43 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 23 Sep 2025 17:50:23 +0000 Subject: [PATCH 18/21] remove unnecessary import Signed-off-by: Varun Thumbe --- transformer_engine/jax/cpp_extensions/activation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index fd41e1d330..d19c093147 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1025,8 +1025,6 @@ def _jax_quantize_dact_dbias( return dx, dbias -from dataclasses import dataclass - def act_lu( x: jnp.ndarray, From c649ae68c0b630e7b69074c8bdf9b39942236bac Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Sep 2025 14:48:58 -0700 Subject: [PATCH 19/21] Update transformer_engine/pytorch/ops/basic/activation.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/ops/basic/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 8a754c6382..d8ddf8b594 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -403,7 +403,7 @@ class ClampedSwiGLU(_ActivationOperation): 1. Both gate and pre-activations are clipped based on parameter limit. 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. 
Parameters From c8aa9053ae3c665c54b419793fe98cc150d33de6 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Sep 2025 14:49:12 -0700 Subject: [PATCH 20/21] Update transformer_engine/jax/flax/module.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: vthumbe1503 --- transformer_engine/jax/flax/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 9234a1d083..7b248c33bf 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -902,7 +902,7 @@ class LayerNormMLP(TransformerEngineBase): The parameters for the ClampedSwiglu activation used in GPT OSS. This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`activations`. At the moment ClampedSwiglu is the only activation that requires parameters. If there is more activation - functions that require parameters in the future, we might need to change it to a more gerneric + functions that require parameters in the future, we might need to change it to a more generic parameter container. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. From 563aa115d4ab781a8f1f4c2e1364ca7df55a074d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 23 Sep 2025 21:56:12 +0000 Subject: [PATCH 21/21] fix the comment Signed-off-by: Varun Thumbe --- transformer_engine/jax/flax/module.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 9234a1d083..586eb705f9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -899,11 +899,9 @@ class LayerNormMLP(TransformerEngineBase): The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. activation_params: dict, default = None - The parameters for the ClampedSwiglu activation used in GPT OSS. This is only - used when ('clamped_silu', 'clamped_linear') is in :attr:`activations`. At the moment - ClampedSwiglu is the only activation that requires parameters. If there is more activation - functions that require parameters in the future, we might need to change it to a more gerneric - parameter container. + The parameters needed (if any) by the activation functions specified in :attr:`activations`. + At the moment, only ('clamped_silu', 'clamped_linear'), the clamped SwiGLU used in GPT OSS, + needs additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1
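As a usage illustration of the documented `activation_params` field, a hedged sketch with the flax `LayerNormMLP` module follows; the constructor arguments other than `activations`/`activation_params`, the input shape, and the hyperparameter values are assumptions for the example, not taken from the patch:

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

mlp = LayerNormMLP(
    intermediate_dim=2048,
    activations=("clamped_silu", "clamped_linear"),
    activation_params={"limit": 7.0, "alpha": 1.702},  # forwarded to ActivationParams.create
)
x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
y = mlp.apply(variables, x, deterministic=True)
```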