diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9e39b84c0b..116ac7aea7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -42,7 +42,7 @@ noop_quantizer_set, ) from transformer_engine.jax.quantize import helper -from transformer_engine.jax.activation import activation +from transformer_engine.jax.activation import activation, ActivationParams from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense @@ -170,29 +170,35 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { "L0": [ ("gelu",), ("gelu", "linear"), + ("clamped_silu", "clamped_linear"), ], "L2": ALL_ACTIVATION_TYPES, } class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +215,16 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} + act_params = ( + ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 +244,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,9 +253,17 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) + act_params = ( + ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -273,10 +292,14 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} + act_params = ( + ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +319,14 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = {"limit": 0.75, "alpha": 1.702} if activation_type == ("clamped_silu", "clamped_linear") else {} + act_params = ( + ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040d..10a6459c3a 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU, }; /*! \brief Computes the GeLU activation of the input. diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43c..f9e6537f18 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -8,20 +8,54 @@ from typing import Sequence, Union, Callable, Optional from functools import partial +from dataclasses import dataclass import jax import jax.numpy as jnp - +import numpy as np from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer +@dataclass(frozen=True) +class ClampedSwigluParams: + limit: float = 7.0 + alpha: float = 1.702 + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + def __hash__(self): + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + +@dataclass(frozen=True) +class ActivationParams: + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + # Add other activation-specific parameter fields here as needed in the future + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + if activation_type == ("clamped_silu", "clamped_linear") or activation_type == "clamped_silu" or activation_type == "clamped_linear": + return ActivationParams(ClampedSwigluParams(**kwargs)) + else: + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. @@ -32,17 +66,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +88,42 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +132,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index cdda201668..d19c093147 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -13,9 +13,9 @@ from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -36,6 +36,7 @@ DelayedScaleQuantizer, ScalingMode, ) +from ..activation import ActivationParams if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -56,17 +57,27 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid( + alpha * jnp.minimum(x, limit) + ) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -89,7 +100,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -105,11 +117,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -155,6 +168,7 @@ def lowering( is_2x, scale_dtype, is_outer, + act_params, ): """ te_gated_act_lu_p lowering rules @@ -163,9 +177,14 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -180,6 +199,7 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation @@ -198,6 +218,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -226,6 +247,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -247,6 +269,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -260,6 +283,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -271,6 +295,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -323,6 +348,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -383,6 +409,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -410,11 +437,12 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params prefix = "ActLuPrimitive_" x_rank = len(value_types[0].shape) scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( @@ -458,8 +486,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -477,6 +505,7 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract @@ -533,6 +562,7 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), scaling_mode, is_2x, + act_params, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -578,6 +608,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -596,6 +627,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -611,6 +643,7 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl @@ -630,6 +663,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -658,6 +692,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -688,6 +723,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -702,11 +738,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -777,6 +814,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -857,6 +895,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -883,11 +922,12 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params prefix = "BaseDActLuDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 @@ -922,20 +962,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -950,10 +992,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -961,7 +1005,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -980,10 +1025,12 @@ def _jax_quantize_dact_dbias( return dx, dbias + def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1005,24 +1052,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1034,6 +1079,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1048,10 +1094,10 @@ def act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1071,6 +1117,7 @@ def act_lu( is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1093,6 +1140,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1109,7 +1157,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1122,8 +1170,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1139,6 +1186,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1154,7 +1202,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1171,6 +1223,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, ) if war_output is not None: return war_output @@ -1178,10 +1231,11 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( - dz=dz, - x=x, + dz=dz.astype(jnp.float32), + x=x.astype(jnp.float32), activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1194,7 +1248,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1220,6 +1277,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1248,6 +1306,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1261,11 +1320,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f0..0982e0ac88 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -38,6 +38,16 @@ XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu_config; +}; + + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -134,4 +144,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember( + "clamped_swiglu")); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906bb..7e7e3178b4 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -133,20 +140,23 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("act_enum") - .Attr("scaling_mode") - .Attr("is_2x"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x") + .Attr( + "act_params"), // Can generalize the config later if we have more activations that need params + FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -216,7 +226,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, + ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu_config.limit; + auto swiglu_alpha = act_params.clamped_swiglu_config.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -383,6 +397,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -408,7 +426,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c1..08600fd3f4 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -133,6 +133,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54efa..586eb705f9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed(if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS + need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 @@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1028,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1037,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9ae..fc72b9bc3f 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1631,6 +1631,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. @@ -1751,6 +1754,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2045,6 +2049,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801af..3878d5cc51 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .activation import ActivationParams from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -48,6 +49,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -129,6 +131,7 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ) return output @@ -154,6 +157,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -203,6 +207,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ) return output @@ -227,6 +232,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -306,10 +312,15 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) + # At the moment the act_params is only used for ClampedSwiglu + # If there are more activations that require parameters in the future + # we might need to change it to a more generic parameter container casted_act_out = tex.act_lu( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + act_params=ActivationParams.create(activation_type, **activation_params) + if activation_params else None, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -371,6 +382,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, ctx, grad, ): @@ -457,6 +469,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=ActivationParams.create(activation_type, **activation_params) if activation_params else None, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 8a754c6382..d8ddf8b594 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -403,7 +403,7 @@ class ClampedSwiGLU(_ActivationOperation): 1. Both gate and pre-activations are clipped based on parameter limit. 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. Parameters