README.rst: 2 changes (1 addition, 1 deletion)
@@ -137,7 +137,7 @@ Flax
for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

For a more comprehensive tutorial, check out our `Quickstart Notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb>`_.
For a more comprehensive tutorial, check out our `Getting Started Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/getting_started.html>`_.

.. overview-end-marker-do-not-remove

examples/README.md: 6 changes (3 additions, 3 deletions)
@@ -23,8 +23,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
- Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
- Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)

# JAX
@@ -34,7 +32,9 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- Model Parallelism: Divide a model across multiple GPUs for parallel training.
- Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)

- [TE JAX Integration Tutorial](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb)
- Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on integrating TE modules like Linear and LayerNorm.

# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
- Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.
transformer_engine/pytorch/graph.py: 29 changes (18 additions, 11 deletions)
@@ -451,11 +451,12 @@ def hook_fn(
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
outputs_requiring_grad = tuple(
o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
outputs_requiring_grad,
grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
grad_inputs = tuple(input.grad for input in inputs)

@@ -616,27 +617,32 @@ def hook_fn(
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
(o.shape, o.dtype, o.layout)
for o in static_outputs
if o is not None and o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(
o for o in static_outputs if o is not None and o.requires_grad
),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
@@ -719,15 +725,16 @@ def hook_fn(
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(o for o in static_outputs if o is not None and o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
@@ -794,7 +801,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
# Replay forward graph
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
return tuple(o.detach() if o is not None else o for o in static_outputs)

@staticmethod
@torch.autograd.function.once_differentiable
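
The recurring change across these hunks is a guard for None entries in the static output tuple: before torch.autograd.backward is called, outputs are filtered to those that are not None and require grad, grad-output buffers are allocated only for those outputs, and replayed forward outputs are detached only when they are actual tensors. Below is a minimal, standalone sketch of that filtering pattern; the function name backward_over_valid_outputs and the toy inputs are illustrative, not part of Transformer Engine.

import torch

def backward_over_valid_outputs(outputs):
    """Run backward only over outputs that exist and require grad."""
    # Keep only real tensors that participate in autograd; None placeholders
    # and tensors that do not require grad are skipped.
    outputs_requiring_grad = tuple(
        o for o in outputs if o is not None and o.requires_grad
    )
    # One grad buffer per surviving output. empty_like mirrors the warmup path,
    # where the gradient values themselves do not matter.
    grad_tensors = tuple(torch.empty_like(o) for o in outputs_requiring_grad)
    torch.autograd.backward(outputs_requiring_grad, grad_tensors=grad_tensors)

# Example: an output tuple that contains a None placeholder.
x = torch.randn(4, 4, requires_grad=True)
outputs = (x.sum(), None, (x * 3).mean())
backward_over_valid_outputs(outputs)
print(x.grad is not None)  # True: the None entry is simply skipped

The same guard appears on the replay path, where only non-None static outputs are detached before being returned, and in the (shape, dtype, layout) keys used to reuse static grad-output buffers across graphs.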