61 commits
96db8f5
split wgrad for GroupedLinear
lhb8125 Mar 12, 2025
4d3326e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2025
94f1892
support wgrad split for linear and ln_linear
lhb8125 Mar 12, 2025
4cac7d0
add comments and fix WeightGradStore
lhb8125 Apr 3, 2025
981ed83
support bias and fix unit tests
lhb8125 Apr 3, 2025
d5f8376
minor fix
lhb8125 Apr 7, 2025
6c23454
support fuse_grad_accumulation=false
lhb8125 Apr 7, 2025
cb49c70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2025
ba5dc5d
Enable reuse of dummy wgrad tensor (#1651)
vasunvidia Apr 8, 2025
9d4e11e
[PyTorch] Debug GEMM refactor (#1652)
timmoon10 Apr 8, 2025
962d9c5
[JAX] Scaling Enum Abstracting (#1655)
phu0ngng Apr 9, 2025
20e95ba
[PyTorch] Explicitly specify quantized tensor usages needed for linea…
timmoon10 Apr 9, 2025
0da6044
[PyTorch] Debug checkpointing with te.Sequential (#1629)
timmoon10 Apr 9, 2025
76eea17
Merge branch 'main' into hongbinl/split_wgrad_new
ksivaman Apr 10, 2025
a8f0fe0
Blockwise scaling linear quantization recipe (#1559)
kwyss-nvidia Apr 10, 2025
2856c3e
Add user to TE CI (#1669)
ksivaman Apr 11, 2025
d91ed12
add wgrad split for layernorm_mlp
lhb8125 Apr 11, 2025
a8e786c
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 11, 2025
3b38bb4
minor fix
lhb8125 Apr 11, 2025
7ec4182
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
dfb3c48
Make shape cache invalidation more conservative. (#1670)
kwyss-nvidia Apr 11, 2025
04642bf
[PyTorch] Add option in activation ops to cache input in FP8 (#1665)
timmoon10 Apr 11, 2025
c638c43
[QA] Extend error handling (#1660)
linxiddd Apr 12, 2025
d9eb058
[PyTorch] Added attention activation offloading support for TE v2.0 (…
sanandaraj5597 Apr 14, 2025
c8e7cc0
[MoE] Support new fp8 recipes for permute_fusion (#1649)
Autumn1998 Apr 14, 2025
38e18f7
fix unittest
lhb8125 Apr 14, 2025
7aefe67
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 14, 2025
5f16c79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2025
98b4c0d
[JAX] grouped_gemm() uses variadic arguments (#1658)
huanghua1994 Apr 14, 2025
6117b20
Add experimental Shardy support. (#1642)
jreiffers Apr 14, 2025
4c9626e
[PyTorch][MoE] Enable New Recipes for Grouped Linear (#1525)
yaox12 Apr 14, 2025
48f3ca9
[PyTorch] Avoid unnecessary tensor usages when caching for linear op …
timmoon10 Apr 14, 2025
5fdd7bb
[PyTorch] check and try to generate fp8 weight transpose cache before…
shjwudp Apr 14, 2025
313ab4f
[JAX] Improving the test_multiprocessing_encoder.py run script (#1673)
phu0ngng Apr 15, 2025
aee7883
[PyTorch] Fix for checkpointing for callables. (#1679)
pggPL Apr 15, 2025
66d6afb
[PyTorch] More precise test for the CPU offloading. (#1668)
pggPL Apr 15, 2025
86928e0
Add adam bf16 state with original fp32 kernel (#1640)
BestJuly Apr 15, 2025
0994fb4
Fix #1524 and other softmax mask functionality (#1681)
KshitijLakhani Apr 16, 2025
beaecf8
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 1 – core (#1614)
pggPL Apr 16, 2025
cd03509
add unittest for distributed interface apply Dener's suggestion
lhb8125 Apr 16, 2025
6a00d24
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 16, 2025
92b80ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2025
8ffbbab
README.md - Installation section (#1689)
sbhavani Apr 16, 2025
541acc7
minor fix
lhb8125 Apr 17, 2025
5131080
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 17, 2025
61312d6
[PyTorch] Deprecate the weight offloading (#1678)
pggPL Apr 17, 2025
06306ce
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 17, 2025
a0cabb7
[QA] Add XML log generation for pytest results (#1661)
linxiddd Apr 17, 2025
61f1bf6
Support computing zero-centered gamma in compute dtype for CuDNN (#1690)
jberchtold-nvidia Apr 17, 2025
e61ce77
Allow NVTEShape to own data. (#1674)
kwyss-nvidia Apr 17, 2025
4e036c8
[PyTorch] Move swizzle scaling factor to cpp (#1683)
yaox12 Apr 17, 2025
39c0e70
Re Do symmetric memory merge request (#1682)
wdykas Apr 17, 2025
34c2f8f
Merge branch 'main' into hongbinl/split_wgrad_new
lhb8125 Apr 18, 2025
1a6a6d7
[JAX] Deprecate Praxis layers (#1694)
phu0ngng Apr 18, 2025
0fbb286
replace split_bw with delay_wgrad_compute
lhb8125 Apr 18, 2025
559e9bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2025
7b29265
Update transformer_engine/pytorch/module/layernorm_mlp.py
ksivaman Apr 18, 2025
bfb3d37
Update transformer_engine/pytorch/module/linear.py
ksivaman Apr 18, 2025
c630beb
Update transformer_engine/pytorch/module/layernorm_linear.py
ksivaman Apr 18, 2025
20fd1cd
Merge branch 'main' into hongbinl/split_wgrad_new
ksivaman Apr 18, 2025
2a7087e
remove comments
lhb8125 Apr 18, 2025
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -50,6 +50,7 @@ jobs:
|| github.actor == 'BestJuly'
|| github.actor == 'xiaopoc'
|| github.actor == 'jreiffers'
|| github.actor == 'lhb8125'
)
steps:
- name: Check if comment is issued by authorized person
145 changes: 116 additions & 29 deletions README.rst
@@ -145,18 +145,30 @@ Flax

Installation
============
.. installation

Pre-requisites
System Requirements
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 12.1+ (CUDA 12.8+ for Blackwell)
* NVIDIA Driver supporting CUDA 12.1 or later
* cuDNN 9.3 or later

Docker
^^^^^^^^^^^^^^^^^^^^
* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere

* **OS:** Linux (official), WSL2 (limited support)

* **Software:**

* CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers
* cuDNN: 9.3+
* Compiler: GCC 9+ or Clang 10+ with C++17 support
* Python: 3.12 recommended

* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+

* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell); a quick capability check is sketched below
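
A quick way to confirm the FP8 requirement is to query each GPU's compute capability. This is a minimal sketch; it assumes a driver recent enough for ``nvidia-smi`` to expose the ``compute_cap`` query field:

.. code-block:: bash

nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader  # FP8 needs 8.9 or higher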

Installation Methods
^^^^^^^^^^^^^^^^^^^^

Docker (Recommended)
^^^^^^^^^^^^^^^^^^^^
The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example, to use the NGC PyTorch container interactively,
@@ -167,41 +179,116 @@ For example to use the NGC PyTorch container interactively,
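
The actual ``docker run`` invocation sits in the collapsed region of this diff; for orientation, it has roughly this shape (the image tag is illustrative):

.. code-block:: bash

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3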

Where 25.01 (corresponding to the January 2025 release) is the container version.

pip
^^^^^^^^^^^^^^^^^^^^
To install the latest stable version of Transformer Engine,
**Benefits of using NGC containers:**

* All dependencies pre-installed with compatible versions and optimized configurations
* NGC PyTorch 23.08+ containers include FlashAttention-2

pip Installation
^^^^^^^^^^^^^^^^^^^

**Prerequisites for pip installation:**

* A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed

To install the latest stable version with pip:

.. code-block:: bash

pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
# For PyTorch integration
pip install --no-build-isolation transformer_engine[pytorch]

# For JAX integration
pip install --no-build-isolation transformer_engine[jax]

# For both frameworks
pip install --no-build-isolation transformer_engine[pytorch,jax]

Alternatively, install directly from the GitHub repository:

.. code-block:: bash

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

This will automatically detect if any supported deep learning frameworks are installed and build
Transformer Engine support for them. To explicitly specify frameworks, set the environment variable
NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).

Alternatively, the package can be directly installed from
`Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
When installing from GitHub, you can explicitly specify frameworks using the environment variable:

.. code-block:: bash

pip3 install transformer_engine[pytorch]
NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

Source Installation
^^^^^^^^^^^^^^^^^^^

`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_

To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be
explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]).
Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX
and PyTorch extensions.
Environment Variables
^^^^^^^^^^^^^^^^^^^^^
These environment variables can be set before installation to customize the build process; a combined example follows the list:

From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.
* **CUDA_PATH**: Path to CUDA installation
* **CUDNN_PATH**: Path to cuDNN installation
* **CXX**: Path to C++ compiler
* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``)
* **MAX_JOBS**: Limit number of parallel build jobs (default varies by system)
* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job
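
As an illustration, a source build that pins the toolkit locations and throttles build parallelism might look like the following sketch (the paths are placeholders, not defaults):

.. code-block:: bash

# Hypothetical locations -- adjust to your CUDA/cuDNN installations
export CUDA_PATH=/usr/local/cuda
export CUDNN_PATH=/usr/local/cudnn
export CXX=g++
export NVTE_FRAMEWORK=pytorch,jax
# Throttle parallelism to keep memory usage in check
MAX_JOBS=4 NVTE_BUILD_THREADS_PER_JOB=1 pip install --no-build-isolation .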

Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.
Compiling with FlashAttention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.

You can verify which FlashAttention version is being used by setting these environment variables:

.. code-block:: bash

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
.. troubleshooting-begin-marker-do-not-remove
Troubleshooting
^^^^^^^^^^^^^^^^^^^

**Common Issues and Solutions:**

1. **ABI Compatibility Issues:**

* **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine
* **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with a matching ABI (a quick check is sketched below).
* **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers.
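
One way to check which ABI the installed PyTorch uses (a sketch relying on PyTorch's public helper):

.. code-block:: bash

python -c "import torch; print(torch.compiled_with_cxx11_abi())"  # True means C++11 ABI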

2. **Missing Headers or Libraries:**

* **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.)
* **Solution:** Install missing development packages or set environment variables to point to correct locations:

.. code-block:: bash

export CUDA_PATH=/path/to/cuda
export CUDNN_PATH=/path/to/cudnn

* If CMake can't find a C++ compiler, set the ``CXX`` environment variable.
* Ensure all paths are correctly set before installation.

3. **Build Resource Issues:**

* **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors
* **Solution:** Limit parallel builds:

.. code-block:: bash

MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ...

4. **Verbose Build Logging:**

* For detailed build logs to help diagnose issues:

.. code-block:: bash

cd transformer_engine
pip install -v -v -v --no-build-isolation .

.. troubleshooting-end-marker-do-not-remove

Breaking Changes
================
20 changes: 12 additions & 8 deletions docs/installation.rst
@@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr

.. code-block:: bash

pip3 install transformer_engine[pytorch]
pip3 install --no-build-isolation transformer_engine[pytorch]

To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.

@@ -54,7 +54,7 @@ Execute the following command to install the latest stable version of Transforme

.. code-block:: bash

pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`).

@@ -71,15 +71,15 @@ Execute the following command to install the latest development build of Transfo

.. code-block:: bash

pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@main
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`.
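
For example, to build only the core C++ library from the development branch (a sketch of the option described above):

.. code-block:: bash

NVTE_FRAMEWORK=none pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main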

In order to install a specific PR, execute (after changing NNN to the PR number):

.. code-block:: bash

pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge


Installation (from source)
@@ -93,8 +93,8 @@ Execute the following commands to install Transformer Engine from source:
git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git

cd TransformerEngine
export NVTE_FRAMEWORK=pytorch # Optionally set framework
pip3 install . # Build and install
export NVTE_FRAMEWORK=pytorch # Optionally set framework
pip3 install --no-build-isolation . # Build and install

If the Git repository has already been cloned, make sure to also clone the submodules:
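
As a reminder (the actual block sits in the collapsed region of this diff), the usual command to fetch the submodules is:

.. code-block:: bash

git submodule update --init --recursive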

@@ -106,10 +106,14 @@ Extra dependencies for testing can be installed by setting the "test" option:

.. code-block:: bash

pip3 install .[test]
pip3 install --no-build-isolation .[test]

To build the C++ extensions with debug symbols, e.g. with the `-g` flag:

.. code-block:: bash

pip3 install . --global-option=--debug
pip3 install --no-build-isolation . --global-option=--debug

.. include:: ../README.rst
:start-after: troubleshooting-begin-marker-do-not-remove
:end-before: troubleshooting-end-marker-do-not-remove
62 changes: 48 additions & 14 deletions examples/jax/encoder/run_test_multiprocessing_encoder.sh
@@ -4,20 +4,54 @@

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
done
wait
# Define the test cases to run
TEST_CASES=(
"test_te_bf16"
"test_te_delayed_scaling_fp8"
"test_te_mxfp8"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
)

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"

HAS_FAILURE=0 # Global failure flag

# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
echo "=== Starting test: $TEST_CASE ..."

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_CASE}_gpu_${i}.log"

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
# Run pytest and redirect stdout and stderr to the log file
pytest -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
done

# Wait for the process to finish
wait

# Check and print the log content accordingly
if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
tail -n +7 "${TEST_CASE}_gpu_0.log"
elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "Invalid ${TEST_CASE}_gpu_0.log"
fi

# Remove the log file after processing it
rm ${TEST_CASE}_gpu_*.log
done
wait

exit $HAS_FAILURE
42 changes: 37 additions & 5 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -258,6 +258,8 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

num_gpu = jax.local_device_count()
@@ -441,20 +443,22 @@ def encoder_parser(args):
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
def setUp(self):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
self.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
@@ -503,6 +507,34 @@ def test_te_mxfp8_with_sp(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785

# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.


if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))