diff --git a/.gitignore b/.gitignore index 7a86041a1e..789d3b0a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,4 @@ compile_commands.json .nfs tensor_dumps/ artifacts/ -*.DS_Store +.DS_Store diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index b372d39879..8d19d3182b 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2 +Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index d1d54c0dda..90f68653cc 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -143,6 +143,86 @@ Tensor saving and restoring functions .. autoapifunction:: transformer_engine.pytorch.restore_from_saved +Operation fuser +--------------- + +.. autoapiclass:: transformer_engine.pytorch.ops.Sequential + :members: forward + +.. autoapiclass:: transformer_engine.pytorch.ops.FusibleOperation + :members: fuser_forward, fuser_backward + +.. autoapiclass:: transformer_engine.pytorch.ops.BasicOperation + :members: op_forward, op_backward + +.. autoapiclass:: transformer_engine.pytorch.ops.FusedOperation + :members: fuser_forward, fuser_backward + +.. autoapifunction:: transformer_engine.pytorch.ops.register_forward_fusion + +.. autoapifunction:: transformer_engine.pytorch.ops.register_backward_fusion + +.. autoapiclass:: transformer_engine.pytorch.ops.Linear + +.. autoapiclass:: transformer_engine.pytorch.ops.AddExtraInput + +.. autoapiclass:: transformer_engine.pytorch.ops.AllGather + +.. autoapiclass:: transformer_engine.pytorch.ops.AllReduce + +.. autoapiclass:: transformer_engine.pytorch.ops.BasicLinear + :members: _functional_forward, _functional_backward + +.. autoapiclass:: transformer_engine.pytorch.ops.Bias + +.. autoapiclass:: transformer_engine.pytorch.ops.ClampedSwiGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ConstantScale + +.. autoapiclass:: transformer_engine.pytorch.ops.Dropout + +.. autoapiclass:: transformer_engine.pytorch.ops.GEGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.GELU + +.. autoapiclass:: transformer_engine.pytorch.ops.GLU + +.. autoapiclass:: transformer_engine.pytorch.ops.GroupedLinear + +.. autoapiclass:: transformer_engine.pytorch.ops.Identity + +.. autoapiclass:: transformer_engine.pytorch.ops.L2Normalization + +.. autoapiclass:: transformer_engine.pytorch.ops.LayerNorm + +.. autoapiclass:: transformer_engine.pytorch.ops.MakeExtraOutput + +.. autoapiclass:: transformer_engine.pytorch.ops.QGELU + +.. autoapiclass:: transformer_engine.pytorch.ops.QGEGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.Quantize + +.. autoapiclass:: transformer_engine.pytorch.ops.ReGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ReLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ReduceScatter + +.. autoapiclass:: transformer_engine.pytorch.ops.Reshape + +.. autoapiclass:: transformer_engine.pytorch.ops.RMSNorm + +.. autoapiclass:: transformer_engine.pytorch.ops.SReGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SReLU + +.. autoapiclass:: transformer_engine.pytorch.ops.ScaledSwiGLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SiLU + +.. autoapiclass:: transformer_engine.pytorch.ops.SwiGLU + Deprecated functions -------------------- diff --git a/docs/examples/op_fuser/fp8_layernorm_linear.png b/docs/examples/op_fuser/fp8_layernorm_linear.png new file mode 100644 index 0000000000..b5916a6152 Binary files /dev/null and b/docs/examples/op_fuser/fp8_layernorm_linear.png differ diff --git a/docs/examples/op_fuser/layernorm_mlp.png b/docs/examples/op_fuser/layernorm_mlp.png new file mode 100644 index 0000000000..f388c88fa9 Binary files /dev/null and b/docs/examples/op_fuser/layernorm_mlp.png differ diff --git a/docs/examples/op_fuser/op_fuser.rst b/docs/examples/op_fuser/op_fuser.rst new file mode 100644 index 0000000000..9613ba74b3 --- /dev/null +++ b/docs/examples/op_fuser/op_fuser.rst @@ -0,0 +1,353 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Operation fuser API +=================== + +Motivation +---------- + +Transformer Engine relies heavily on operation fusion to achieve high +performance. A typical training workload involves many memory-bound +operations such as activation functions and normalization, so +replacing them with fused kernels can deliver a significant +performance benefit. This is especially true for low-precision +training (e.g. FP8 and FP4) because it involves extra cast operations. + +Managing these fusions can be challenging because they differ based on +operation types, communication patterns, data types, and GPU +architectures. The most straightforward solution is to provide +monolithic modules like ``Linear``, ``LayerNormLinear``, or +``TransformerLayer``. These conform to the interface of a standard +PyTorch module, but can perform arbitrary fusions internally. These +hand-tuned implementations can achieve maximum performance, but they +tend to be complicated and difficult to modify. + +As an alternative to this "top-down" design, TE exposes a "bottom-up" +operation-based API. The user constructs individual operations and +passes them into a fuser, resulting in the same fused kernels as the +monolithic modules. This approach is more flexible, making it easier +to support new model architectures or to experiment with fusions. + +Basic usage +----------- + +Sequential operations +^^^^^^^^^^^^^^^^^^^^^ + +At the most basic level, the operation fuser API involves two classes +in the ``transformer_engine.pytorch.ops`` submodule: + +- ``FusibleOperation``: An abstract base class for tensor operations. + Examples include ``Linear``, ``LayerNorm``, and ``AllReduce``. It is + a subclass of ``torch.nn.Module``, so it can hold trainable + parameters and can be called to perform the operation's forward + pass. +- ``Sequential``: A container of modules in sequential order. Its + interface is very similar to ``torch.nn.Sequential``. If it contains + any ``FusibleOperation`` s, then it may attempt to fuse them in the + forward and backward passes. + +Thus, using the operation fuser simply involves constructing +``FusibleOperation`` s and passing them into a ``Sequential``. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Options + hidden_size = 4096 + ffn_size = 28672 + batch_size = 16384 + + # Construct operations and fuse + mlp = te.ops.Sequential( + te.ops.LayerNorm(hidden_size), + te.ops.Linear(hidden_size, ffn_size), + te.ops.SwiGLU(), + te.ops.Linear(ffn_size // 2, hidden_size), + ) + + # Forward pass + x = torch.randn(batch_size, hidden_size, device="cuda") + y = mlp(x) + +.. figure:: ./layernorm_mlp.png + :align: center + + Operations that match ``LayerNormMLP`` module. Note that different + fusions have been applied in the forward and backward passes. + +Quantization +^^^^^^^^^^^^ + +The operation fuser respects TE's APIs for low-precision ("quantized") +data formats like FP8 and FP4. Constructing operations within a +``quantized_model_init`` context will enable quantized weights and +performing the forward pass within an ``autocast`` context will enable +quantized compute. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Linear(4096, 28672), + ) + + # Forward pass within autocast context + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = fc1(x) + + # Backward pass outside of autocast context + y.sum().backward() + +Branching operations +^^^^^^^^^^^^^^^^^^^^ + +The operation fuser supports very limited branching behavior. While +the operations must be in sequential order, some operations can accept +extra inputs or produce extra outputs. For example, ``AddExtraInput`` +will add an extra input tensor to the intermediate tensor and +``MakeExtraOutput`` will return the intermediate tensor as an extra +output. When calling a ``Sequential`` that contains any of these +branching operations, the extra inputs should be passed in as +arguments and the extra outputs will be returned. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct MLP with residual connection + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.MakeExtraOutput(), # Output residual + te.ops.Linear(4096, 28672), + te.ops.SwiGLU(), + ) + fc2 = te.ops.Sequential( + te.ops.Linear(14336, 4096), + te.ops.AddExtraInput(), # Add residual + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + y, residual = fc1(x) + y = fc2(y, residual) + +.. figure:: ./residual_layernorm_mlp.png + :align: center + + Operations for an MLP block with a residual connection. Note that + the block has been split into two sections, each with one branching + operation. + +Developer guide +--------------- + +Infrastructure +^^^^^^^^^^^^^^ + +In addition to ``FusibleOperation`` and ``Sequential``, the fuser +infrastructure relies on the following classes: + +- ``BasicOperation``: The most basic type of ``FusibleOperation``. + Examples include ``BasicLinear``, ``Bias``, and ``ReLU``. It holds + parameters and state, and it implements both a forward and backward + pass. The ``op_forward`` and ``op_backward`` functions have an + interface reminiscent of ``torch.autograd.Function``, e.g. they + accept a context object that caches state from the forward pass to + the backward pass. +- ``FusedOperation``: A ``FusibleOperation`` that can replace one or + more ``BasicOperation`` s. Examples include + ``ForwardLinearBiasActivation`` and ``BackwardActivationBias``. Its + forward and backward passes (the ``fuser_forward`` and + ``fuser_backward`` functions) must produce equivalent results as its + corresponding ``BasicOperation`` s. This also means that the + ``FusedOperation`` is stateless since it can access parameters and + state from the ``BasicOperation`` s. Note that different fusions may + be applied in the forward and backward pass, so a ``FusedOperation`` + may be missing its forward and/or backward implementation. +- ``OperationFuser``: This is the class that manages the operation + fusions. It launches the forward and backward passes within a + ``torch.autograd.Function``. It can also replace operations with + equivalent ``FusedOperation`` s. + +The first time that a ``Sequential`` is called, it will group adjacent +``FusibleOperation`` s together into ``OperationFuser`` s. The first +time an ``OperationFuser`` is called, it will attempt to fuse +operations for the forward pass and backward pass. Subsequent calls +will reuse the same state unless it has been invalidated, e.g. by +changing the quantization recipe. + +Quantization +^^^^^^^^^^^^ + +Each operation that supports quantized compute holds one or more +``Quantizer`` s, which are builder classes for converting +high-precision tensors (e.g. in FP32 or BF16) to quantized tensors. In +order to enable fused quantization kernels, operations can access the +quantizers of neighboring operations and quantize eagerly. + +.. figure:: ./fp8_layernorm_linear.png + :align: center + + Operations that match ``LayerNormLinear`` module with FP8 + quantization. + +In some situations, like when operations are split across multiple +``Sequential`` s, it may be helpful to encourage the fuser by manually +adding ``Quantize`` operations. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + norm = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Quantize(), + ) + fc1 = te.ops.Sequential( + te.ops.Linear(4096, 28672), + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = norm(x) # y is a QuantizedTensor + z = fc1(y) + +.. warning:: + + This is an expert technique. Quantizer configurations can be quite + complicated, so the ``Quantize`` operation's quantizers may be + suboptimal. + +Implementing new operations +--------------------------- + +Implementing a basic operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Subclasses of ``BasicOperation`` must implement ``op_forward`` and +``op_backward``, which are reminiscent of the ``forward`` and +``backward`` methods of ``torch.autograd.Function``. They have an +argument for a context object that can be used to cache state from the +forward pass for use in the backward pass. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + class LearnableScale(te.ops.BasicOperation): + + def __init__(self) -> None: + super().__init__() + scale = torch.ones((), dtype=torch.float32, device="cuda") + self.register_parameter("scale", torch.nn.Parameter(scale)) + + def op_forward(self, ctx, input_: torch.Tensor, **unused) -> torch.Tensor: + out = self.scale * input_ + ctx.save_for_backward(self.scale, input_) + return out + + def op_backward( + self, + ctx, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + scale, input_ = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)).reshape(()) + grad_input = scale * grad_output + return ( + grad_input, # Input gradient + (grad_scale,), # Param gradients + ) + +Implementing a fused operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Subclasses of ``FusedOperation`` should declare their corresponding +``BasicOperation`` s in the constructor. They should also implement +``fuser_forward`` and ``fuser_backward``, depending on usage. These +functions are similar to ``op_forward`` and ``op_backward`` from +``BasicOperation``, but some arguments and returns are lists. For +example, instead of taking a single context object, they take a list +of context objects for all the corresponding ``BasicOperation`` s. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + from typing import Optional + + class ForwardAxpy(te.ops.FusedOperation): + + def __init__(self, scale: te.ops.ConstantScale, add: te.ops.AddExtraInput) -> None: + super().__init__((scale, add)) # Equivalent basic ops + + def fuser_forward( + self, + basic_op_ctxs: list, + input_: torch.Tensor, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + **unused, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, ...]]]: + scale_op, add_op = self.basic_ops + extra_input = basic_op_extra_inputs[1][0] # Extra input to add op + out = scale_op.scale * input_ + extra_input + scale_ctx, add_ctx = basic_op_ctxs # No state needed for backward + return ( + out, # Output + [(), ()], # Extra outputs for each basic op + ) + +.. warning:: + + Remember the contract that the fused operation must produce outputs + that are interchangeable with the corresponding basic operation + outputs. + +In order to make these fused operations useful, they should be +registered with the operation fuser. To do this, first implement a +fusion function that can replace operations with the fused operation, +and then register it with the ``register_forward_fusion`` or +``register_backward_fusion`` functions. + +.. code-block:: python + + def fuse_axpy_ops( + ops: list[te.ops.FusibleOperation], + **unused, + ) -> list[te.ops.FusibleOperation]: + """Sliding window scan to perform ForwardAxpy fusion""" + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if ( + isinstance(window[0], te.ops.ConstantScale) + and isinstance(window[1], te.ops.AddExtraInput) + ): + window = [ForwardAxpy(window[0], window[1])] + else: + out.append(window[0]) + window = window[1:] + window, ops = window + ops[:1], ops[1:] + out.extend(window + ops) + return out + + # Register fusion with operation fuser + te.ops.register_forward_fusion(fuse_axpy_ops) diff --git a/docs/examples/op_fuser/residual_layernorm_mlp.png b/docs/examples/op_fuser/residual_layernorm_mlp.png new file mode 100644 index 0000000000..fa95114a69 Binary files /dev/null and b/docs/examples/op_fuser/residual_layernorm_mlp.png differ diff --git a/docs/index.rst b/docs/index.rst index 336cd2d47f..194e76df24 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,6 +57,7 @@ Transformer Engine documentation examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb examples/te_jax_integration.ipynb + examples/op_fuser/op_fuser.rst .. toctree:: :hidden: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 5d1a5ce61d..f95f065d78 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3428,8 +3428,15 @@ def test_custom_basic_op( ) -> None: """Custom basic op""" - class CustomScaleOp(te.ops.BasicOperation): - """Custom op that applies a learnable scale""" + class LearnableScale(te.ops.BasicOperation): + """Custom op that applies a learnable scale + + This class is as an example in the op fuser guide at + docs/examples/op_fuser/op_fuser.rst (see "Implementing a + basic operation"). Any changes made to this class should + also be made there. + + """ def __init__(self) -> None: super().__init__() @@ -3442,23 +3449,19 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_output_quantizer: Optional[Quantizer], - next_op_input_quantizer: Optional[Quantizer], + **unused, ) -> torch.Tensor: + out = self.scale * input_ ctx.save_for_backward(self.scale, input_) - return self.scale * input_ + return out def op_backward( self, ctx: OperationContext, grad_output: torch.Tensor, - ) -> torch.Tensor: - ( - scale, - input_, - ) = ctx.saved_tensors - grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)) - grad_scale = grad_scale.reshape(()) + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + scale, input_ = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)).reshape(()) grad_input = scale * grad_output return grad_input, (grad_scale,) @@ -3485,7 +3488,7 @@ def op_backward( y_ref.backward(dy_ref) # Implementation with fusible operation - op = CustomScaleOp() + op = LearnableScale() forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity()) with torch.no_grad(): op.scale.copy_(w_test) @@ -3502,7 +3505,112 @@ def op_backward( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) - def test_custom_forward_fused_op( + def test_custom_forward_fused_op1( + self, + *, + shape: Iterable[int] = (5, 11), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in forward pass""" + + class ForwardAxpy(te.ops.FusedOperation): + """Custom op that computes BLAS SAXPY in forward pass + + This class is as an example in the op fuser guide at + docs/examples/op_fuser/op_fuser.rst (see "Implementing a + fused operation"). Any changes made to this class should + also be made there. + + """ + + _enabled = True + + def __init__( + self, + scale: te.ops.ConstantScale, + add: te.ops.AddExtraInput, + ) -> None: + super().__init__((scale, add)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + **unused, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, ...]]]: + scale_op, add_op = self.basic_ops + extra_input = basic_op_extra_inputs[1][0] # Extra input to add op + out = scale_op.scale * input_ + extra_input + scale_ctx, add_ctx = basic_op_ctxs # No state needed for backward + return ( + out, # Output + [(), ()], # Extra outputs for each basic op + ) + + def fuse_axpy_ops( + ops: list[te.ops.FusibleOperation], + **unused, + ) -> list[te.ops.FusibleOperation]: + """Apply fusion the first time this function is called""" + if ForwardAxpy._enabled: + ForwardAxpy._enabled = False + else: + return ops + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if isinstance(window[0], te.ops.ConstantScale) and isinstance( + window[1], te.ops.AddExtraInput + ): + window = [ForwardAxpy(*window)] + else: + out.append(window[0]) + window = window[1:] + window, ops = window + ops[:1], ops[1:] + out.extend(window + ops) + return out + + # Random data + scale = 0.5 + x1_ref, x1_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = scale * x1_ref + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_forward_fusion(fuse_axpy_ops) + model = te.ops.Sequential( + te.ops.ConstantScale(scale=scale), + te.ops.AddExtraInput(), + ) + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check values + tols = dtype_tols(dtype) + assert_close(y_test, y_ref, **tols) + assert_close_grads(x1_test, x1_ref, **tols) + assert_close_grads(x2_test, x2_ref, **tols) + + def test_custom_forward_fused_op2( self, *, shape: Iterable[int] = (7, 11), diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9e23bb3fb1..13cb519c19 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -152,7 +152,7 @@ class GELU(_ActivationOperation): \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) - See `Gaussian Error Linear Units (GELUs)`__. + See `Gaussian Error Linear Units (GELUs) `__. """ @@ -183,8 +183,8 @@ class GLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `Language Modeling with Gated Convolutional Networks`__ - and `GLU Variants Improve Transformer`__. + See `Language Modeling with Gated Convolutional Networks `__ + and `GLU Variants Improve Transformer `__. """ @@ -219,7 +219,7 @@ class GEGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -233,8 +233,8 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: class QGELU(_ActivationOperation): r"""Quick Gaussian Error Linear Unit - Quick GELU from `HuggingFace`__ - and `paper`__. + Quick GELU from `HuggingFace `__ + and `paper `__. .. math:: @@ -316,7 +316,7 @@ class ReGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -334,7 +334,7 @@ class SReLU(_ActivationOperation): \text{SReLU}(x) = \max(x^2,0) - See `Primer: Searching for Efficient Transformers for Language Modeling`__. + See `Primer: Searching for Efficient Transformers for Language Modeling `__. """ diff --git a/transformer_engine/pytorch/ops/basic/add_extra_input.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py index 47f2b6e248..fc3ca9cade 100644 --- a/transformer_engine/pytorch/ops/basic/add_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py @@ -30,7 +30,7 @@ class AddExtraInput(BasicOperation): feature and most users are discouraged from it. In-place operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `MakeExtraOutput`, which does a similar operation in + Compare to ``MakeExtraOutput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..48376a297f 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -48,8 +48,8 @@ def _wait_async(handle: Optional[Any]) -> None: class BasicLinear(BasicOperation): """Apply linear transformation: :math:`y = x A^T` - This is a drop-in replacement for `torch.nn.Linear` with - `bias=False`. + This is a drop-in replacement for ``torch.nn.Linear`` with + ``bias=False``. Parameters ---------- @@ -61,27 +61,27 @@ class BasicLinear(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) rng_state_tracker_function : callable - Function that returns `CudaRNGStatesTracker`, which is used + Function that returns ``CudaRNGStatesTracker``, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be - meaningful. This is primarily intented to integrate with + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that ``grad`` will be set or be + meaningful. This is primarily intended to integrate with Megatron-LM. This argument along with weight tensor having - attribute 'overwrite_main_grad' set to True will overwrite - `main_grad` instead of accumulating. + attribute ``overwrite_main_grad`` set to ``True`` will + overwrite ``main_grad`` instead of accumulating. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -184,7 +184,7 @@ def _canonicalize_tensor_parallelism( Parameters ---------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -200,7 +200,7 @@ def _canonicalize_tensor_parallelism( Returns ------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -446,18 +446,18 @@ def _functional_forward( Output tensor beta: float, optional Scaling factor applied to original value of out when accumulating into it - accumulate_into_out: bool, default = `False` + accumulate_into_out: bool, default = False Add result to output tensor instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -465,10 +465,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. @@ -477,11 +477,11 @@ def _functional_forward( torch.Tensor Output tensor torch.Tensor, optional - Input tensor, ready for use in backward pass. `None` is + Input tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the weight tensor is not required. torch.Tensor, optional - Weight tensor, ready for use in backward pass. `None` is + Weight tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the input tensor is not required. @@ -682,24 +682,24 @@ def _functional_backward( Loss gradient w.r.t. weight tensor grad_weight_beta: float, optional Scaling factor applied to original value of grad_weight when accumulating into it - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting grad_input: torch.Tensor, optional Loss gradient w.r.t. input tensor grad_input_beta: float, optional Scaling factor applied to original value of grad_input when accumulating into it - accumulate_into_grad_input: bool, default = `False` + accumulate_into_grad_input: bool, default = False Add result to input grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8b60251088..d580f84866 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -18,7 +18,7 @@ class Bias(BasicOperation): """Apply additive bias - This is equivalent to the additive bias in `torch.nn.Linear`. + This is equivalent to the additive bias in ``torch.nn.Linear``. Parameters ---------- @@ -28,7 +28,7 @@ class Bias(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel : bool, default = `False` + tensor_parallel : bool, default = False Whether to distribute input tensor and bias tensors along inner dimension tensor_parallel_group : torch.distributed.ProcessGroup, default = world group diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index eb8a67600d..b44e77b0c6 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -65,7 +65,7 @@ class GroupedLinear(BasicOperation): weight's ``main_grad`` attribute instead of relying on PyTorch autograd. The weight's ``main_grad`` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. This is primarily intented to integrate with + meaningful. This is primarily intended to integrate with Megatron-LM. This argument along with weight tensor having attribute ``overwrite_main_grad`` set to True will overwrite ``main_grad`` instead of accumulating. diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 631f0fafc9..3fda5145c6 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -31,7 +31,7 @@ class LayerNorm(BasicOperation): r"""Layer Normalization Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ + the paper `Layer Normalization `__ . .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta @@ -51,9 +51,9 @@ class LayerNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero - and the calculation changes to + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to + zero and the calculation changes to .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 61caaaf65d..0d9c870262 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -35,7 +35,7 @@ class MakeExtraOutput(BasicOperation): operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `AddExtraInput`, which does a similar operation in the + Compare to ``AddExtraInput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b5..fa3efc3807 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -18,14 +18,14 @@ class Quantize(BasicOperation): """Quantize tensor data - Uses recipe from `autocast` context. When called outside - of an `autocast` context, this is an identity operation. + Uses recipe from ``autocast`` context. When called outside + of an ``autocast`` context, this is an identity operation. Parameters ---------- - forward : bool, default = `True` + forward : bool, default = True Perform quantization in forward pass - backward : bool, default = `False` + backward : bool, default = False Perform quantization in backward pass """ diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index f8ae86fecd..4a171c294b 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -20,7 +20,7 @@ class Reshape(BasicOperation): """Reshape tensor - See `torch.reshape`. + See ``torch.reshape``. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 3179d0a447..1d8d8be971 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -32,7 +32,7 @@ class RMSNorm(BasicOperation): Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper - `Root Mean Square Layer Normalization `__ + `Root Mean Square Layer Normalization `__ . .. math:: y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma @@ -50,8 +50,8 @@ class RMSNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to zero and the calculation changes to .. math:: diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index eaffbeee02..b4427df41a 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -46,7 +46,7 @@ class SwiGLU(BasicOperation): The Sigmoid Linear Unit (SiLU) gating function is also known as the swish function. See - ``GLU Variants Improve Transformer``__. + `GLU Variants Improve Transformer `__. Parameters ---------- @@ -189,14 +189,18 @@ def op_backward( class ClampedSwiGLU(BasicOperation): r"""GPT-OSS - Implementation based on ``GPT-OSS``__. + Implementation based on `GPT-OSS `__. This activation has two differences compared to the original SwiGLU 1. Both gate and pre-activations are clipped based on parameter limit. 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + .. warning:: + + The input tensor is chunked along the last dimension to get + gates/pre-activations which is different from GPT OSS + implementation where the gates/pre-activations are assumed to + be interleaved in the input tensor. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 90ade030c8..fbaf69d75d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -125,18 +125,18 @@ def _functional_backward( Tensor datatype grad_weight: torch.Tensor, optional Loss gradient w.r.t. weight tensor - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 6ef9bf083b..0d3e1d0416 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -115,16 +115,16 @@ def _functional_forward( Tensor device dtype: torch.dtype Tensor datatype - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -132,10 +132,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. ub_comm_name: str diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 7fe6ea37ed..bd3bc94b60 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -31,7 +31,7 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: def _is_graph_capturing() -> bool: - """Whether function is called within `make_graphed_callables` + """Whether function is called within ``make_graphed_callables`` Avoid circular import with lazy import. @@ -519,6 +519,8 @@ def register_forward_fusion( The fusion function should have the following signature: + .. code-block:: python + func(ops, *, recipe) -> updated ops Parameters @@ -545,6 +547,8 @@ def register_backward_fusion( The fusion function should have the following signature: + .. code-block:: python + func(ops, *, recipe) -> updated ops Parameters diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index d5829b0c50..c6ca4786b8 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -23,7 +23,7 @@ class Linear(FusedOperation): """Apply linear transformation: :math:`y = x A^T + b` - This is a drop-in replacement for `torch.nn.Linear`. + This is a drop-in replacement for ``torch.nn.Linear``. Parameters ---------- @@ -31,17 +31,17 @@ class Linear(FusedOperation): Inner dimension of input tensor out_features : int Inner dimension of output tensor - bias : bool, default = `True` + bias : bool, default = True Apply additive bias device : torch.device, default = default CUDA device Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing @@ -49,12 +49,12 @@ class Linear(FusedOperation): rng_state_tracker_function : callable Function that returns CudaRNGStatesTracker, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be - meaningful. This is primarily intented to integrate with + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally and + there is no guarantee that ``grad`` will be set or be + meaningful. This is primarily intended to integrate with Megatron-LM. """ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47286dfced..54b3f00117 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -94,7 +94,7 @@ def fuser_forward( several of this function's arguments are lists of arguments to forward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- @@ -141,7 +141,7 @@ def fuser_backward( several of this function's arguments are lists of arguments to backward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index a0db3cd2d0..592ddae23a 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -15,10 +15,10 @@ class Sequential(torch.nn.Module): - """Sequential container for fusible operations + """Sequential container for fusible operations. - This is a drop-in replacement for `torch.nn.Sequential`, with - support for fusing `FusibleOperation`s. + This is a drop-in replacement for ``torch.nn.Sequential`` with + support for fusing ``FusibleOperation`` s. Parameters ----------