diff --git a/README.rst b/README.rst
index 3cc5f81293..5a6721b04c 100644
--- a/README.rst
+++ b/README.rst
@@ -137,7 +137,7 @@ Flax
     for _ in range(10):
         loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
 
-For a more comprehensive tutorial, check out our `Quickstart Notebook `_.
+For a more comprehensive tutorial, check out our `Getting Started Guide `_.
 
 .. overview-end-marker-do-not-remove
 
diff --git a/examples/README.md b/examples/README.md
index 004d1631f1..782dc42f58 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -23,8 +23,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
   - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
 - [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
   - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
-- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
-  - Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
 - [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)
 
 # JAX
@@ -34,7 +32,9 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
   - Model Parallelism: Divide a model across multiple GPUs for parallel training.
   - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
 - [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)
-
+- [TE JAX Integration Tutorial](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb)
+  - Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on integrating TE modules like Linear and LayerNorm.
+
 # Third party
 - [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
   - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 37fff943d6..f4b1fb23ae 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -451,11 +451,12 @@ def hook_fn(
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
                 with _none_grad_context_wrapper(inputs):
+                    outputs_requiring_grad = tuple(
+                        o for o in outputs if o is not None and o.requires_grad
+                    )
                     torch.autograd.backward(
-                        tuple(o for o in outputs if o.requires_grad),
-                        grad_tensors=tuple(
-                            torch.empty_like(o) for o in outputs if o.requires_grad
-                        ),
+                        outputs_requiring_grad,
+                        grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                     )
 
                 grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +617,22 @@ def hook_fn(
             # Note for _reuse_graph_input_output_buffers: grad output is only used
             # within backward, so we can reuse the same static buffers every time.
             static_grad_outputs_keys = tuple(
-                (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                (o.shape, o.dtype, o.layout)
+                for o in static_outputs
+                if o is not None and o.requires_grad
             )
             if static_grad_outputs_keys in static_grad_outputs_dict:
                 static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
             else:
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
                     for o in static_outputs
                 )
                 static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
         else:
             static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
             )
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +640,9 @@ def hook_fn(
                 bwd_graph, pool=mempool
             ):
                 torch.autograd.backward(
-                    tuple(o for o in static_outputs if o.requires_grad),
+                    tuple(
+                        o for o in static_outputs if o is not None and o.requires_grad
+                    ),
                     grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
@@ -719,7 +725,8 @@ def hook_fn(
         ):
             # For now, assumes all static_outputs require grad
             static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
             )
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +734,7 @@ def hook_fn(
                     bwd_graph, pool=mempool
                 ):
                     torch.autograd.backward(
-                        tuple(o for o in static_outputs if o.requires_grad),
+                        tuple(o for o in static_outputs if o is not None and o.requires_grad),
                         grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
@@ -794,7 +801,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
         # Replay forward graph
         fwd_graph.replay()
         assert isinstance(static_outputs, tuple)
-        return tuple(o.detach() for o in static_outputs)
+        return tuple(o.detach() if o is not None else o for o in static_outputs)
 
     @staticmethod
     @torch.autograd.function.once_differentiable