diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 8c3b84b1b9..f86fdd1066 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -50,6 +50,7 @@ jobs: || github.actor == 'BestJuly' || github.actor == 'xiaopoc' || github.actor == 'jreiffers' + || github.actor == 'lhb8125' ) steps: - name: Check if comment is issued by authorized person diff --git a/README.rst b/README.rst index c4fde5bd11..3313a2625b 100644 --- a/README.rst +++ b/README.rst @@ -145,18 +145,30 @@ Flax Installation ============ -.. installation -Pre-requisites +System Requirements ^^^^^^^^^^^^^^^^^^^^ -* Linux x86_64 -* CUDA 12.1+ (CUDA 12.8+ for Blackwell) -* NVIDIA Driver supporting CUDA 12.1 or later -* cuDNN 9.3 or later -Docker -^^^^^^^^^^^^^^^^^^^^ +* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere + +* **OS:** Linux (official), WSL2 (limited support) + +* **Software:** + + * CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers + * cuDNN: 9.3+ + * Compiler: GCC 9+ or Clang 10+ with C++17 support + * Python: 3.12 recommended + +* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+ +* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell) + +Installation Methods +^^^^^^^^^^^^^^^^^^^ + +Docker (Recommended) +^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on `NVIDIA GPU Cloud (NGC) Catalog `_. For example to use the NGC PyTorch container interactively, @@ -167,41 +179,116 @@ For example to use the NGC PyTorch container interactively, Where 25.01 (corresponding to January 2025 release) is the container version. -pip -^^^^^^^^^^^^^^^^^^^^ -To install the latest stable version of Transformer Engine, +**Benefits of using NGC containers:** + +* All dependencies pre-installed with compatible versions and optimized configurations +* NGC PyTorch 23.08+ containers include FlashAttention-2 + +pip Installation +^^^^^^^^^^^^^^^^^^^ + +**Prerequisites for pip installation:** + +* A compatible C++ compiler +* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed + +To install the latest stable version with pip: .. code-block:: bash - pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable + # For PyTorch integration + pip install --no-build-isolation transformer_engine[pytorch] + + # For JAX integration + pip install --no-build-isolation transformer_engine[jax] + + # For both frameworks + pip install --no-build-isolation transformer_engine[pytorch,jax] + +Alternatively, install directly from the GitHub repository: + +.. code-block:: bash -This will automatically detect if any supported deep learning frameworks are installed and build -Transformer Engine support for them. To explicitly specify frameworks, set the environment variable -NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). + pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -Alternatively, the package can be directly installed from -`Transformer Engine's PyPI `_, e.g. +When installing from GitHub, you can explicitly specify frameworks using the environment variable: .. code-block:: bash - pip3 install transformer_engine[pytorch] + NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + +Source Installation +^^^^^^^^^^^^^^^^^^^ + +`See the installation guide `_ -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be -explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). -Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX -and PyTorch extensions. +Environment Variables +^^^^^^^^^^^^^^^^^^^ +These environment variables can be set before installation to customize the build process: -From source -^^^^^^^^^^^ -`See the installation guide `_. +* **CUDA_PATH**: Path to CUDA installation +* **CUDNN_PATH**: Path to cuDNN installation +* **CXX**: Path to C++ compiler +* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``) +* **MAX_JOBS**: Limit number of parallel build jobs (default varies by system) +* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job -Compiling with FlashAttention-2 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance. +Compiling with FlashAttention +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment. + +You can verify which FlashAttention version is being used by setting these environment variables: + +.. code-block:: bash + + NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. -Note that NGC PyTorch 23.08+ containers include FlashAttention-2. +.. troubleshooting-begin-marker-do-not-remove +Troubleshooting +^^^^^^^^^^^^^^^^^^^ + +**Common Issues and Solutions:** + +1. **ABI Compatibility Issues:** + + * **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine + * **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with matching ABI. + * **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers. + +2. **Missing Headers or Libraries:** + + * **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.) + * **Solution:** Install missing development packages or set environment variables to point to correct locations: + + .. code-block:: bash + + export CUDA_PATH=/path/to/cuda + export CUDNN_PATH=/path/to/cudnn + + * If CMake can't find a C++ compiler, set the ``CXX`` environment variable. + * Ensure all paths are correctly set before installation. + +3. **Build Resource Issues:** + + * **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors + * **Solution:** Limit parallel builds: + + .. code-block:: bash + + MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ... + +4. **Verbose Build Logging:** + + * For detailed build logs to help diagnose issues: + + .. code-block:: bash + + cd transformer_engine + pip install -v -v -v --no-build-isolation . + +.. troubleshooting-end-marker-do-not-remove Breaking Changes ================ diff --git a/docs/installation.rst b/docs/installation.rst index 10046d6306..d0d6cf96d2 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI "$LOG_FILE" 2>&1 & + done + + # Wait for the process to finish + wait + + # Check and print the log content accordingly + if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + HAS_FAILURE=1 + echo "... $TEST_CASE FAILED" + tail -n +7 "${TEST_CASE}_gpu_0.log" + elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" + else + echo "Invalid ${TEST_CASE}_gpu_0.log" + fi + + # Remove the log file after processing it + rm ${TEST_CASE}_gpu_*.log done -wait + +exit $HAS_FAILURE diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7e6605c9fe..0577787b7c 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -258,6 +258,8 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -441,6 +443,9 @@ def encoder_parser(args): parser.add_argument( "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -448,13 +453,12 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): + def setUp(self): """Run 3 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -503,6 +507,34 @@ def test_te_mxfp8_with_sp(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.455 and actual[1] > 0.785 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + self.args.enable_shardy = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_with_sp_shardy(self): + """Test Transformer Engine with DelayedScaling FP8 + SP""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ba62d964fa..c196692757 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -238,6 +238,7 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -409,6 +410,9 @@ def encoder_parser(args): default="DelayedScaling", help="Use FP8 recipe (default: DelayedScaling)", ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -416,13 +420,12 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): + def setUp(self): """Run 3 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -446,6 +449,24 @@ def test_te_mxfp8(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.535 and actual[1] > 0.73 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + self.args.enable_shardy = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.535 and actual[1] > 0.73 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.535 and actual[1] > 0.73 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index a2b160b522..352160a8ed 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -343,6 +343,7 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) if args.process_id == 0: nltk.download("punkt_tab") @@ -565,6 +566,9 @@ def encoder_parser(args): default=0, help="the ID number of the current process (default: 0)", ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -573,7 +577,7 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - def exec(self, use_fp8, fp8_recipe): + def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): """Run 3 epochs for testing""" args = encoder_parser([]) @@ -589,6 +593,7 @@ def exec(self, use_fp8, fp8_recipe): args.num_process = num_gpu args.process_id = self.process_id args.fp8_recipe = fp8_recipe + args.enable_shardy = enable_shardy return train_and_evaluate(args) @@ -604,7 +609,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -614,6 +619,22 @@ def test_te_mxfp8(self): result = self.exec(True, "MXFP8BlockScaling") assert result[0] < 0.505 and result[1] > 0.754 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + result = self.exec(False, None, enable_shardy=True) + assert result[0] < 0.505 and result[1] > 0.755 + + @unittest.skipIf( + not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" + ) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + result = self.exec(True, "DelayedScaling", enable_shardy=True) + assert result[0] < 0.505 and result[1] > 0.754 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1300be01bb..1783ca8177 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -327,13 +327,12 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): - """Run 4 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + def setUp(self): + """Run 3 epochs for testing""" + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 4022cb7493..435750a1db 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -306,8 +306,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3253861484..92434c28ea 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -17,13 +17,17 @@ RET=0 FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +wait +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +wait . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" if [ $RET -ne 0 ]; then diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 7989eaf528..6ffc5945a2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -2,6 +2,8 @@ # # See LICENSE for license information. +set -x + function error_exit() { echo "Error: $1" exit 1 @@ -18,20 +20,23 @@ FAILED_CASES="" pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" + : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls" +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 1206012195..8e37a83dea 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -19,29 +19,31 @@ FAILED_CASES="" set -x : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 96c5949a99..5deb77af91 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -5,5 +5,7 @@ set -xe : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 36d491ecd3..03997489e8 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -17,16 +17,18 @@ RET=0 FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential -python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index 1737ca9ba1..edf3f2eb84 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -5,9 +5,11 @@ set -x : ${THUNDER_PATH:=/opt/pytorch/lightning-thunder} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. Lightning tests are skipped on systems diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index ec651a1317..07eb0fc8f1 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -2,22 +2,45 @@ # # See LICENSE for license information. -set -xe +set -x -pip install "nltk>=3.8.2" -pip install pytest==8.2.1 +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py +NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" -pip install -r $TE_PATH/examples/jax/mnist/requirements.txt -pip install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 3e83ef7f52..a42ec035e8 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -5,6 +5,8 @@ set -e : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 @@ -37,6 +39,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py done diff --git a/setup.py b/setup.py index e1977601f5..97fb292c51 100644 --- a/setup.py +++ b/setup.py @@ -110,12 +110,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch>=2.1"]) + install_reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) - # test_reqs.extend(["numpy", "praxis"]) test_reqs.extend(["numpy"]) return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cc27f72769..10b52e065f 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -19,6 +19,12 @@ using namespace test; namespace { +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + constexpr size_t kBlockLen = 128; enum ProcessingMethod { @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector(&input, fill_case); fillUniform(&grad); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, - opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, fillUniform(&grad); Tensor workspace; + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, - {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {1, 16}, {65, 96}, {256, 256}, {993, 512}, + {256, 65536}, {4096, 1632}, {1024, 1}, + {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, }; std::vector input_scenarios = { @@ -429,6 +441,8 @@ std::vector Activation_types = { std::vector amax_epsilons = { 0.0f, + 1.0f, // Make large to be observable. + }; } // namespace @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 0004c2ce74..a0ca938fbf 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -18,159 +18,16 @@ #include #include #include "../test_common.h" +#include "test_normalization.h" using namespace transformer_engine; using namespace test; namespace { -enum NormType { - LayerNorm, - RMSNorm -}; - -std::map normToString = { - {NormType::LayerNorm, "LayerNorm"}, - {NormType::RMSNorm, "RmsNorm"} -}; - -template -void compute_ref_stats(NormType norm_type, - const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon){ - using compute_t = float; - compute_t current, m; - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - sum += static_cast(data[i * H + j]); - } - if (norm_type == LayerNorm){ - mu[i] = sum / H; - m = mu[i]; - } else { m = 0;} - - compute_t sum_sq = 0; - for (size_t j = 0; j < H; ++j) { - current = static_cast(data[i * H + j]); - sum_sq += (current - m) * (current - m); - } - rsigma[i] = rsqrtf((sum_sq / H) + epsilon); - } -} - -// For now, cudnn does static_cast(gamma + static_cast(1.0)) -// This will be changed in the future release -template -inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){ - - using compute_t = float; - if constexpr (std::is_same_v || std::is_same_v){ - compute_t g = static_cast(gamma); - if (zero_centered_gamma) { - g += static_cast(1.f); - } - return g; - } else { - if (use_cudnn){ - compute_t g = static_cast(0.f); - InputType gi = gamma; - if (zero_centered_gamma) { - gi = gi + static_cast(1.f); - } - g = static_cast(gi); - return g; - } else { - compute_t g = static_cast(gamma); - if (zero_centered_gamma) { - g += static_cast(1.f); - } - return g; - } - } -} - -template -void compute_ref_output(NormType norm_type, - const InputType *data, const InputType *gamma, const InputType *beta, - OutputType* output, - const float *mu, const float *rsigma, - const size_t N, const size_t H, - float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - - compute_t tmp; - if (norm_type == LayerNorm) { - tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - } else { // RMSNorm - tmp = current * rsigma[i] * g; - } - - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - - -template -void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, - const float *mu, const float *rsigma, - const InputType *gamma, - InputType *data_grad, - InputType *gamma_grad, InputType *beta_grad, - const size_t N, const size_t H, - const bool zero_centered_gamma, const bool use_cudnn) { - using compute_t = float; - std::vector dgamma(H, 0.f); - std::vector dbeta(H, 0.f); - - for (size_t i = 0 ; i < N; ++i) { - // Reductions - auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; - compute_t mdy = 0, mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - local_mu) * rsigma[i]; - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - if (norm_type == LayerNorm) { - dbeta[j] += dz; - mdy += dy; - } - mdyy += dy * y; - } - mdy /= H; - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - local_mu) * rsigma[i]; - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); - if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); -} - template void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, - NormType norm_type, bool use_cudnn) { + NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) { if (sizeof(InputType) < sizeof(OutputType)) { GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; return; @@ -219,9 +76,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); + if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) { + // Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation + GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend"; + } + if (use_cudnn){ nvte_enable_cudnn_norm_fwd(true); nvte_enable_cudnn_norm_bwd(true); + + + // Zero-centered gamma in weight dtype only supported by CuDNN backend currently + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(true); + } else { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } } // Forward kernel @@ -269,6 +139,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if (use_cudnn){ nvte_enable_cudnn_norm_fwd(false); nvte_enable_cudnn_norm_bwd(false); + + // Zero-centered gamma in weight dtype only supported by CuDNN backend currently + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } } // Reference implementations @@ -289,14 +164,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, &ref_amax, ref_scale, zero_centered_gamma, - use_cudnn); + use_cudnn, + zero_centered_gamma_in_weight_dtype); compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), N, H, zero_centered_gamma, - use_cudnn); + use_cudnn, + zero_centered_gamma_in_weight_dtype); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -341,6 +218,7 @@ NormType, transformer_engine::DType, transformer_engine::DType, std::pair, + bool, bool>> {}; TEST_P(NormTestSuite, TestNorm) { @@ -353,10 +231,11 @@ TEST_P(NormTestSuite, TestNorm) { const DType output_type = std::get<3>(GetParam()); const auto size = std::get<4>(GetParam()); const bool zero_centered_gamma = std::get<5>(GetParam()); + const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); + performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype); ); ); } @@ -370,6 +249,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), + ::testing::Values(false, true), ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn"; @@ -380,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<3>(info.param)) + "X" + std::to_string(std::get<4>(info.param).first) + "X" + std::to_string(std::get<4>(info.param).second) + "X" + - std::to_string(std::get<5>(info.param)); + std::to_string(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)); return name; }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h new file mode 100644 index 0000000000..368ffa66c9 --- /dev/null +++ b/tests/cpp/operator/test_normalization.h @@ -0,0 +1,178 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + + #pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include "../test_common.h" + +namespace test { +namespace { + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RmsNorm"} +}; + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + compute_t current, m; + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +template +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + + using compute_t = float; + + // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently + // Remove the use_cudnn check here when it is supported by both backends. + const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; + + if constexpr (std::is_same_v || std::is_same_v){ + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } else { + if (zero_centered_gamma_in_weight_dtype){ + compute_t g = static_cast(0.f); + InputType gi = gamma; + if (zero_centered_gamma) { + gi = gi + static_cast(1.f); + } + g = static_cast(gi); + return g; + } else { + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + OutputType* output, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = static_cast(tmp * scale); + current_max = fmaxf(current_max, fabsf(tmp)); + } + } + + if (amax) { + *amax = current_max; + } +} + + +template +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, + const float *mu, const float *rsigma, + const InputType *gamma, + InputType *data_grad, + InputType *gamma_grad, InputType *beta_grad, + const size_t N, const size_t H, + const bool zero_centered_gamma, const bool use_cudnn, + const bool cudnn_zero_centered_gamma_in_weight_dtype) { + using compute_t = float; + std::vector dgamma(H, 0.f); + std::vector dbeta(H, 0.f); + + for (size_t i = 0 ; i < N; ++i) { + // Reductions + auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; + compute_t mdy = 0, mdyy = 0; + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + dgamma[j] += y * dz; + if (norm_type == LayerNorm) { + dbeta[j] += dz; + mdy += dy; + } + mdyy += dy * y; + } + mdy /= H; + mdyy /= H; + + // Input grads + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + data_grad[i * H + j] = static_cast(dx); + } + } + + // Weight grads + for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); + if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); +} + +} // namespace +} // namespace test diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index 191c62835b..4d0cf86034 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -19,6 +19,7 @@ #include #include #include "../test_common.h" +#include "test_normalization.h" using namespace transformer_engine; using namespace test; @@ -27,16 +28,6 @@ namespace { using fp8e8m0 = byte; -enum NormType { - LayerNorm, - RMSNorm -}; - -std::map normToString = { - {NormType::LayerNorm, "LayerNorm"}, - {NormType::RMSNorm, "RMSNorm"} -}; - template void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ @@ -110,65 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training) 32, 1); } -template -void compute_ref_stats(NormType norm_type, - const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon){ - using compute_t = float; - - #pragma omp parallel for proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - sum += static_cast(data[i * H + j]); - } - compute_t m; - if (norm_type == LayerNorm){ - mu[i] = sum / H; - m = mu[i]; - } else { m = 0;} - - compute_t sum_sq = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum_sq += (current - m) * (current - m); - } - rsigma[i] = rsqrtf((sum_sq / H) + epsilon); - } -} - -template -void compute_ref_output(NormType norm_type, - const InputType *data, const InputType *gamma, const InputType *beta, - const float *mu, const float *rsigma, - const size_t N, const size_t H, - OutputType* output, - const bool zero_centered_gamma){ - using compute_t = float; - - #pragma omp parallel for proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1.0; - } - - compute_t tmp; - if (norm_type == LayerNorm) { - tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - } else { // RMSNorm - tmp = current * rsigma[i] * g; - } - - output[i * H + j] = tmp; - } - } -} - template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training, const bool zero_centered_gamma_in_weight_dtype) { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); @@ -195,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, fillUniform(&gamma); fillUniform(&beta); + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(true); + } else { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } + // Forward kernel float epsilon = 1e-5; if (norm_type == NormType::LayerNorm){ @@ -220,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, 0); } + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); dequantize_2x(z, dequantized_output, is_training); @@ -246,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, compute_ref_output(norm_type, input.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), beta.rowwise_cpu_dptr(), + ref_output.get(), ref_mu_ptr, ref_rsigma_ptr, N, H, - ref_output.get(), - zero_centered_gamma); + nullptr, // amax + 1.f, // scale + zero_centered_gamma, + true, // CuDNN is the only MXFP8 backend currently + zero_centered_gamma_in_weight_dtype); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -298,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple, - bool, bool>> {}; + bool, bool, bool>> {}; TEST_P(MxNormTestSuite, TestMxNorm) { using namespace transformer_engine; @@ -310,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) { const auto size = std::get<3>(GetParam()); const bool zero_centered_gamma = std::get<4>(GetParam()); const bool is_training = std::get<5>(GetParam()); + const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training); + performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training, zero_centered_gamma_in_weight_dtype); ); ); } @@ -327,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), ::testing::Values(true, false), + ::testing::Values(true, false), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = normToString.at(std::get<0>(info.param)) + "_" + @@ -335,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<3>(info.param).first) + "X" + std::to_string(std::get<3>(info.param).second) + "X" + std::to_string(std::get<4>(info.param)) + "out" + - std::to_string(int(std::get<5>(info.param)) + 1) + "x"; + std::to_string(int(std::get<5>(info.param)) + 1) + "x" + + std::to_string(std::get<6>(info.param)); return name; }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 071c2186e0..0977c512cb 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -112,8 +112,8 @@ struct scale_inv_meta { size_t type_size; }; -NVTEShape convertShape(const std::vector& shape) { - return {shape.data(), shape.size()}; +NVTEShape convertShape(const std::vector& s) { + return nvte_make_shape(s.data(), s.size()); } std::pair get_scales(const NVTEShape& shape, @@ -216,8 +216,7 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode, - const QuantizationOptions* q_opts) { + const NVTEScalingMode &scaling_mode) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -241,7 +240,7 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); - NVTEShape columnwise_shape{nullptr, 0}; + NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -258,8 +257,7 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { - columnwise_shape.data = columnwise_shape_vec.data(); - columnwise_shape.ndim = columnwise_shape_vec.size(); + columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); } tensor_ = TensorWrapper(scaling_mode); @@ -328,10 +326,6 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } - if (q_opts != nullptr) { - NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); - NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation."); - } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 08df3cf7d1..5e01dacc0a 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,29 +95,21 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; -struct QuantizationOptions { - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - size_t block_scaling_dim = 2u; -}; - class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} Tensor() {} diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini index 1e835b2187..70d4188c5f 100644 --- a/tests/jax/pytest.ini +++ b/tests/jax/pytest.ini @@ -25,3 +25,5 @@ filterwarnings= ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning ignore:Scan loop is disabled for fused ring attention.*:UserWarning + ignore:jax.extend.ffi.register_ffi_target is deprecated + ignore:jax.extend.ffi.ffi_lowering is deprecated diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4dc07a2eea..8917e92465 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -48,21 +48,21 @@ LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) + supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) def is_shape_supported_by_mxfp8(input_shape): try: if isinstance(input_shape, type(pytest.param(0))): input_shape = input_shape.values[0] - ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) + ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True except: # get_scale_shapes will raise an exception if the shape is not supported @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=q_layout, ) @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8( self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout, ) @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_layout=q_layout, ) @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -533,7 +533,7 @@ class TestFusedQuantize: def test_quantize_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis ): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( + if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization( in_dtype=in_dtype, input_shape=input_shape, out_dtype=in_dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=QuantizeLayout.ROWWISE, @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout): scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index bb7f83b319..ecca5ab322 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -48,31 +48,7 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "data_shape", - [ - pytest.param((32, 512, 12, 64), id="32-512-12-64"), - pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), - ], - ) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), - ], - ) - @pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), - ], - ) - @pytest.mark.parametrize("dtype", DTYPES) - def test_self_attn( + def impl_test_self_attn( self, device_count, mesh_shape, @@ -83,7 +59,9 @@ def test_self_attn( bias_shape, attn_mask_type, dtype, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True @@ -137,6 +115,80 @@ def test_self_attn( ) runner.test_backward() + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize( + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], + ) + @pytest.mark.parametrize("dtype", DTYPES) + def test_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + attn_mask_type, + dtype, + ): + self.impl_test_self_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + attn_mask_type, + dtype, + use_shardy=False, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + ], + ) + def test_self_attn_shardy( + self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape + ): + data_shape = (32, 512, 12, 64) + self.impl_test_self_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + AttnMaskType.PADDING_MASK, + jnp.bfloat16, + use_shardy=True, + ) + class TestDistributedCrossAttn: @@ -203,37 +255,23 @@ def test_cross_attn( runner.test_backward() -@pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() -) -@pytest.mark.parametrize( - "data_shape", - [ - # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. - pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), - pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), - ], -) -@pytest.mark.parametrize("kv_groups", [1, 8]) -@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) -@pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - [ - pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), - pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), - pytest.param( - QKVLayout.THD_THD_THD, - AttnMaskType.PADDING_CAUSAL_MASK, - id="THD_SEPARATE-PADDING_CAUSAL", - ), - ], -) -@pytest.mark.parametrize( - "load_balanced", - [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], -) +DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + pytest.param( + QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL" + ), +] + +DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), +] + + class TestDistributedContextParallelSelfAttn: def impl_test_context_parallel_attn( @@ -249,7 +287,23 @@ def impl_test_context_parallel_attn( qkv_layout, load_balanced, cp_strategy, + use_shardy, + use_scan_ring=False, ): + if qkv_layout.is_thd(): + if cp_strategy == CPStrategy.ALL_GATHER: + pytest.skip("THD doesn't support all gather context parallelism.") + if not load_balanced and cp_strategy == CPStrategy.RING: + pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + + assert not use_scan_ring or cp_strategy == CPStrategy.RING + + if use_scan_ring: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" + else: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" + + jax.config.update("jax_use_shardy_partitioner", use_shardy) attn_bias_type = AttnBiasType.NO_BIAS bias_shape = None dropout_prob = 0.0 @@ -324,7 +378,58 @@ def check_has_backend_for_mask(mask_type): pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") runner.test_backward() + del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + def test_context_parallel_allgather_attn_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_mask_type, + dtype, + qkv_layout, + ): + kv_groups = 8 + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced=True, + cp_strategy=CPStrategy.ALL_GATHER, + use_shardy=True, + ) + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], + ) def test_context_parallel_allgather_attn( self, device_count, @@ -338,9 +443,7 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, ): - if qkv_layout.is_thd(): - pytest.skip("THD doesn't support all gather context parallelism.") - return self.impl_test_context_parallel_attn( + self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -352,8 +455,23 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, + use_shardy=False, ) + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], + ) @pytest.mark.parametrize( "use_scan", [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], @@ -372,14 +490,6 @@ def test_context_parallel_ring_attn( load_balanced, use_scan, ): - if use_scan: - os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" - else: - os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - - if qkv_layout.is_thd() and not load_balanced: - pytest.skip("THD + ring doesn't support unbalanced context parallelism.") - self.impl_test_context_parallel_attn( device_count, mesh_shape, @@ -392,9 +502,46 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, CPStrategy.RING, + use_shardy=False, + use_scan_ring=use_scan, + ) + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + def test_context_parallel_ring_attn_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_mask_type, + dtype, + qkv_layout, + ): + kv_groups = 8 + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced=True, + cp_strategy=CPStrategy.RING, + use_shardy=False, + use_scan_ring=True, ) - del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - return class TestReorderCausalLoadBalancing: diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 6d4cde364f..0358a2a2e3 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -29,7 +29,7 @@ } is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: @@ -86,6 +86,7 @@ def generate_collectives_count_ref( @pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_layernorm( self, device_count, @@ -97,7 +98,9 @@ def test_layernorm( zero_centered_gamma, shard_weights, fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "layernorm" q_dtype = jnp.float8_e4m3fn @@ -168,6 +171,7 @@ def ref_func(x, gamma, beta): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_rmsnorm( self, device_count, @@ -178,7 +182,9 @@ def test_rmsnorm( dtype, shard_weights, fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "rmsnorm" q_dtype = jnp.float8_e4m3fn diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 4350d5e8f3..f97f264245 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -36,7 +36,7 @@ is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: @@ -144,16 +144,10 @@ def layernorm_fp8_mlp_prim_func( ) ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + def _test_layernorm_mlp_grad( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -257,9 +251,59 @@ def test_layernorm_mlp_grad( err_msg=f"multi_grads[{i}] is not close", ) + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + def test_layernorm_mlp_grad( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + self._test_layernorm_mlp_grad( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy=False, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + def test_layernorm_mlp_grad_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + # We don't test block scaling with Shardy because at the time of writing, + # it is not supported in JAX's scaled_matmul_stablehlo. + self._test_layernorm_mlp_grad( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe=recipe.DelayedScaling(), + use_shardy=True, + ) + def _test_layernorm_mlp( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8, + fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -322,9 +366,19 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_layernorm_mlp_layer( + self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + ): self._test_layernorm_mlp( - mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=use_shardy, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -345,4 +399,5 @@ def test_layernorm_mlp_layer_fp8( dtype, use_fp8=True, fp8_recipe=fp8_recipe, + use_shardy=False, ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 30a9fd53ea..cb30c34abc 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -28,14 +28,16 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): + def generate_inputs( + self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask + ): batch, _, sqelen, _ = shape x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED: mask = make_causal_mask(batch, sqelen) else: - mask = make_self_mask(batch, sqelen) + mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen) if not bad_sharding: x_pspec = PartitionSpec( @@ -45,7 +47,11 @@ def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_shardin x_pspec = PartitionSpec( mesh_resource.dp_resource, None, None, mesh_resource.tp_resource ) - mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) + + if broadcast_batch_mask: + mask_pspec = PartitionSpec(None, None, None, None) + else: + mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) return (x, mask), (x_pspec, mask_pspec) @@ -67,16 +73,7 @@ def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16): output = jax.nn.softmax(x * scale_factor) return jnp.mean(output) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) - @pytest.mark.parametrize( - "softmax_type", - [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], - ) - @pytest.mark.parametrize("scale_factor", [1.0, 3.0]) - @pytest.mark.parametrize("dtype", DTYPES) - @pytest.mark.parametrize("bad_sharding", [False, True]) - def test_softmax( + def impl_test_softmax( self, device_count, mesh_shape, @@ -87,15 +84,20 @@ def test_softmax( scale_factor, dtype, bad_sharding, + broadcast_batch_mask, + use_shardy, ): + if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED: + pytest.skip("Softmax type has no mask.") + jax.config.update("jax_use_shardy_partitioner", use_shardy) target_func = partial( self.target_func, scale_factor=scale_factor, softmax_type=softmax_type ) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) (x, mask), (x_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, softmax_type, dtype, bad_sharding + data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask ) collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) @@ -129,4 +131,70 @@ def test_softmax( assert "Sharding the hidden dimension is not supported" in str(w), ( "Softmax primitive did not raise the correct warning for " "unsupported sharding in the hidden dimension." + f"{str(w)}" ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize( + "softmax_type", + [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], + ) + @pytest.mark.parametrize("scale_factor", [1.0, 3.0]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("bad_sharding", [False, True]) + @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) + def test_softmax( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + softmax_type, + scale_factor, + dtype, + bad_sharding, + broadcast_batch_mask, + ): + self.impl_test_softmax( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + softmax_type, + scale_factor, + dtype, + bad_sharding, + broadcast_batch_mask, + use_shardy=False, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) + @pytest.mark.parametrize("bad_sharding", [False, True]) + @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) + def test_softmax_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + softmax_type, + bad_sharding, + broadcast_batch_mask, + ): + self.impl_test_softmax( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape=[32, 12, 128, 128], + softmax_type=softmax_type, + scale_factor=1.0, + dtype=DTYPES[0], + bad_sharding=bad_sharding, + broadcast_batch_mask=broadcast_batch_mask, + use_shardy=True, + ) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b89530c19f..d59e130530 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -39,7 +39,7 @@ def enable_fused_attn(): is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) QUANTIZE_RECIPES = [] """ Find supported scaling modes""" @@ -215,12 +215,53 @@ def enable_fused_attn(): _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, }, # attrs22 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "causal", + _KEY_OF_WINDOW_SIZE: None, + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + # attrs23 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "causal", + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + # attrs24 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask", + }, + # attrs25 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask", + _KEY_OF_WINDOW_SIZE: (2, 2), + }, + # attrs26 { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", _KEY_OF_WINDOW_SIZE: (2, 2), }, + # attrs27 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", + _KEY_OF_WINDOW_SIZE: None, + }, + # attrs28 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_WINDOW_SIZE: (2, 2), + }, ] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] @@ -313,7 +354,7 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: _, updated_quantize_meta = flax.core.pop( updated_state[0], QuantizeConfig.COLLECTION_NAME ) @@ -370,13 +411,13 @@ def generate_inputs(self, data_shape, dtype): data_rng = jax.random.PRNGKey(2024) inputs = (jax.random.normal(data_rng, data_shape, dtype),) - padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) - causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) + mask_shape = (batch, 1, seqlen, seqlen) + padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8) + causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1) if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: mask = causal_mask else: mask = padded_mask - ref_masks = (1 - mask,) test_masks = (None, mask) # The second arg of Transformer is encoded tokens. diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 8cc8448979..09386c92ed 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -18,6 +18,7 @@ from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available from transformer_engine.jax.softmax import SoftmaxType, softmax +from transformer_engine.jax.flax.module import Softmax def catch_unsupported(method): @@ -94,7 +95,6 @@ def _setup_inputs(self): case _: raise ValueError(f"Unknown {self.softmax_type=}") - @catch_unsupported def test_forward(self): """ Test transformer_engine.jax.softmax.softmax fwd rule @@ -104,7 +104,6 @@ def test_forward(self): reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - @catch_unsupported def test_backward(self): """ Test transformer_engine.jax.softmax.softmax bwd rule @@ -141,6 +140,50 @@ def grad_func(func, *args, **kwargs): assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) +class SoftmaxPrimitivesRunner(SoftmaxRunner): + """ + Jax Softmax Primitives runner + """ + + @catch_unsupported + def test_forward(self): + return super().test_forward() + + @catch_unsupported + def test_backward(self): + return super().test_backward() + + +class SoftmaxModuleRunner: + """ + Jax Softmax Module runner + """ + + module_runner: SoftmaxRunner + bias: None + + def __init__(self, module_runner, bias): + self.module_runner = module_runner + self.bias = bias + + def test_forward(self): + """ + Test transformer_engine.jax.flax.module.Softmax fwd rule + """ + runner = self.module_runner + runner._setup_inputs() + rng = jax.random.PRNGKey(0) + softmax_module = Softmax( + scale_factor=runner.scale_factor, + softmax_type=runner.softmax_type, + ) + softmax_vars = softmax_module.init(rng, runner.logits, runner.mask) + module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask) + reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor) + assert_allclose(module_out, reference_out, dtype=runner.dtype) + + +# Run softmax primitives test @pytest.mark.parametrize( "b, s_q, s_kv, h", [ @@ -165,7 +208,7 @@ def grad_func(func, *args, **kwargs): pytest.param(jnp.float16, id="FP16"), ], ) -class TestSoftmax: +class TestSoftmaxPrimitives: """ Test transformer_engine.jax.softmax.softmax """ @@ -175,7 +218,7 @@ def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): """ Test forward with parameterized configs """ - runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner.test_forward() @staticmethod @@ -183,5 +226,48 @@ def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): """ Test forward with parameterized configs """ - runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner.test_backward() + + +# Run Softmax module test +@pytest.mark.parametrize( + "b, s_q, s_kv, h", + [ + pytest.param(8, 16, 16, 16, id="8-16-16-16"), + pytest.param(8, 512, 512, 16, id="8-512-512-16"), + pytest.param(2, 8, 16384, 8, id="2-8-16384-8"), + # triggers backup framework implementation due to (s_q % 4) != 0 + pytest.param(8, 511, 512, 16, id="8-511-512-16"), + ], +) +@pytest.mark.parametrize("scale_factor", [0.125]) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(SoftmaxType.SCALED, id="SCALED"), + pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"), + pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(jnp.bfloat16, id="BF16"), + pytest.param(jnp.float16, id="FP16"), + ], +) +class TestSoftmaxModule: + """ + Test transformer_engine.jax.flax.module.Softmax + """ + + @staticmethod + def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): + """ + Test forward with parameterized configs + """ + module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + bias = None + runner = SoftmaxModuleRunner(module_runner, bias) + runner.test_forward() diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index ec06bb7e48..1b38f72512 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -243,10 +243,10 @@ def __init__(self, weights, lr, dp_group): # Flatten the weights and pad to align with world size raw_data_list = [ - _get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1) + _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) for w in weights ] - if isinstance(weights[0], Float8Tensor): + if isinstance(weights[0], QuantizedTensor): raw_data_list = [_get_raw_data(w).view(-1) for w in weights] else: raw_data_list = [w.view(-1) for w in weights] @@ -282,7 +282,7 @@ def __init__(self, weights, lr, dp_group): self.weight_indices.append((None, None)) self.shard_indices.append((None, None)) - if isinstance(weights[idx], Float8Tensor): + if isinstance(weights[idx], QuantizedTensor): replace_raw_data( weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) ) @@ -378,19 +378,13 @@ def step(self): master_weight -= grad * self.lr # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], Float8Tensor): + if isinstance(self.weights[0], QuantizedTensor): local_weights = [] - for model_weight, local_weight in zip(self.weights, self.local_weights): + for local_weight in self.local_weights: if local_weight is None: local_weights.append(None) continue - quantizer = model_weight._get_quantizer() - if isinstance(quantizer, Float8CurrentScalingQuantizer): - local_weight = quantizer.create_tensor_from_data( - local_weight.view(-1), - model_weight.dtype, - ) local_weights.append(local_weight) cast_master_weights_to_fp8( diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ae5993eb1e..621d036212 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -19,6 +19,7 @@ MXFP8BlockScaling, DelayedScaling, Float8CurrentScaling, + Float8BlockScaling, Format, Recipe, ) @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe: return MXFP8BlockScaling() if QUANTIZATION == "fp8_cs": return Float8CurrentScaling() + if QUANTIZATION == "fp8_block_scaling": + return Float8BlockScaling() return te.fp8.get_default_fp8_recipe() @@ -85,7 +88,7 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"): global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE SEQ_LEN = 32 BATCH_SIZE = 32 @@ -297,6 +300,11 @@ def _loss_backward(output_single_node, output_distributed): LOSS_FN(output_distributed, target).backward() +def _loss_backward_dw(model_single_node, model_distributed): + model_single_node.backward_dw() + model_distributed.backward_dw() + + def _alloc_main_grad(model_single_node, model_distributed): for model in [model_single_node, model_distributed]: for param in model.parameters(): @@ -470,6 +478,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + # Compute delayed weight gradient + if "delay_wgrad_compute" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -491,6 +503,7 @@ def test_linear(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"params_dtype": torch.float16}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column", "row"]: @@ -642,6 +655,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + # Compute delayed weight gradient + if "delay_wgrad_compute" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -664,6 +681,7 @@ def test_layernorm_linear(): {"params_dtype": torch.float16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column"]: @@ -743,6 +761,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + if "delay_wgrad_compute" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -768,6 +789,7 @@ def test_layernorm_mlp(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"return_layernorm_output": True}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index b4e2b680b3..632f50e90a 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -28,6 +28,9 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -48,7 +51,7 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -56,4 +59,6 @@ def test_distributed(quantization): pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) _run_test(quantization) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ed7cdda85b..ab4b7634b8 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,41 +2,84 @@ # # See LICENSE for license information. +import os +from contextlib import nullcontext import pytest import torch -from contextlib import nullcontext import transformer_engine.pytorch as te +from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -# Check if FP8 supported +# Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + +fp8_recipes = [ + None, # non-fp8 + # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet + recipe.Float8CurrentScaling(), + recipe.DelayedScaling(), +] SIZE = 512 +NUM_HEADS = 8 +NUM_LAYERS = 5 +EPSILON = 0.1 + +# Flash attention saves some internal tensor for the backward pass +# that cannot be offloaded to CPU. +assert os.getenv("NVTE_FLASH_ATTN") == "0" -models = { - "linear": te.Linear, - "layernorm_mlp": te.LayerNormMLP, - "layernorm_linear": te.LayerNormLinear, +# Offloading is supported for attention only for fused and flash attention backends, +# so the use of bfloat16 is required. +# +# For the TransformerLayer, activation offloading with dropout is not supported, +# so we set hidden_dropout to 0.0. +model_types = { + "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), + "multihead_attention": lambda: te.MultiheadAttention( + SIZE, NUM_HEADS, params_dtype=torch.bfloat16 + ), + "transformer_layer": lambda: te.TransformerLayer( + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + ), } def _get_input(): - return torch.empty((128, SIZE, SIZE)).cuda() + return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() + + +def _get_fp8_weight_cache_size(models, fp8_recipe): + """ + Calculate the total FP8 weight cache size (in MB) for a list of models. + """ + if fp8_recipe is None: + return 0 + params_bytes = 0 + for model in models: + for name, param in model.named_parameters(): + if "weight" in name: + params_bytes += param.numel() -def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): + # One byte for columnwise and one byte for rowwise, + # hence multiply by 2 and convert to MB + # there is 1 byte of scale per 32 elements in mxFP8 + factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 + return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) - input_layer = model_cls(SIZE, SIZE) - hidden_layer = model_cls(SIZE, SIZE) - output_layer = model_cls(SIZE, SIZE) - input = _get_input() +def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): + tensor = _get_input() if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=2, - model_layers=3, + num_layers=len(models) - 1, + model_layers=len(models), offload_activations=True, offload_weights=False, ) @@ -44,42 +87,58 @@ def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): offload_context = nullcontext() sync_function = lambda x: x - with te.fp8_autocast(enabled=fp8), offload_context: - out = input_layer(input) - out = sync_function(out) - with te.fp8_autocast(enabled=fp8), offload_context: - out = hidden_layer(out) - out = sync_function(out) - with te.fp8_autocast(enabled=fp8), offload_context: - out = output_layer(out) - out = sync_function(out) - - max_mem_used = torch.cuda.memory_allocated() / 1024**2 - - out.sum().backward() - - del input_layer - del hidden_layer - del output_layer - del input - del out + for model in models: + with te.fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + ), offload_context: + tensor = model(tensor) + tensor = sync_function(tensor) + max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() return max_mem_used -@pytest.mark.parametrize("fp8", [True, False]) -@pytest.mark.parametrize("model_key", models.keys()) -def test_cpu_offload(fp8, model_key) -> None: +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("model_key", model_types.keys()) +def test_cpu_offload(fp8_recipe, model_key) -> None: + """ + We run three configurations: + (1) No offloading: All activations remain on the GPU between forward and backward passes. + (2) No offloading (one layer): Only the first layer's activations remain on the GPU between + forward and backward passes. + (3) With offloading (all layers): Only the last layer's activations remain on the GPU + between forward and backward passes, while all other layers are offloaded to the CPU. - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) + We expect the memory consumption of configurations (2) and (3) to be similar, with + the difference being the size of the FP8 cache that is not offloaded to the CPU. + We also expect this memory consumption to be smaller than in scenario (1). + """ - model_cls = models[model_key] + model_cls = model_types[model_key] + models_list = [model_cls() for _ in range(NUM_LAYERS)] - without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) - - with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) + if fp8_recipe and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None: + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + without_offloading = _measure_memory_between_forward_and_backward( + models_list, fp8_recipe, False + ) + without_offloading_one_layer = _measure_memory_between_forward_and_backward( + models_list[:1], fp8_recipe, False + ) + with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) assert with_offloading < without_offloading + + # The only difference between the memory consumption of with_offloading + # and without_offloading_one_layer should be the size of the FP8 weights cache, + # which is not offloaded to the CPU. + memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) + assert ( + memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5a1dc3f732..7bfe506f26 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -27,6 +27,9 @@ # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -55,6 +58,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] # Supported data types @@ -316,9 +320,13 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and module == "linear_op": + pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") # Run model with different CUDA graph settings. model_config = model_configs[model_config] kwargs = dict( diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9a1cfa2db8..ec23cfe8c5 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,21 +8,18 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ) + supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + return supported def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e638fe8c5b..0baee4975d 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -4,11 +4,14 @@ from typing import Tuple import math +import os +import pathlib import pytest import torch import transformer_engine as te import transformer_engine_torch as tex -from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -18,10 +21,29 @@ BlockwiseQuantizerReference, QuantizeResult, ) +from test_float8_current_scaling_exact import ( + TestFP8RecipeLinearBase, + TestFP8RecipeLayerNormLinearBase, +) + +# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + + +class GetRecipes: -# TODO replace with call to fp8.py when recipe added. -recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 -reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + @staticmethod + def none(): + return None + + @staticmethod + def fp8_blockwise(): + # return default configs + return Float8BlockScaling() def initialize_for_many_scales( @@ -66,35 +88,7 @@ def initialize_for_many_scales( return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles. - (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] -) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) -def test_quantization_block_tiling_versus_reference( +def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, @@ -199,12 +193,90 @@ def test_quantization_block_tiling_versus_reference( [ # full tile cases (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (256, 256), + (2048, 1024), + # Padding required cases + (256, 272), + (303, 300), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference_fp32_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( @@ -292,3 +364,130 @@ def test_quantization_block_tiling_extrema_versus_reference( atol=0.0, rtol=0.0, ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1.6, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 9741b1258c..8911ecc159 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,8 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - return torch.mean(torch.abs((a - b) / b)) + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) @staticmethod def _load_golden_tensor_values(a, b): @@ -97,9 +98,12 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { @@ -437,9 +441,13 @@ def _check_golden_tensor_dumps( fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" + current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index d030426b74..6d3e879970 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -110,7 +110,10 @@ def _test_quantize_dequantize( dims = _to_list(dims) # Initialize random data + # Note: Make sure values are not all close to zero, or else + # test may pass trivially. x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref.view(-1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back @@ -150,6 +153,24 @@ def test_quantize_dequantize_dtypes( ) self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1]) + def test_quantize_dequantize_columnwise_only( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=False, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True + ) + @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 42600e3099..d36da704b0 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import io -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytest import torch @@ -158,6 +158,32 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("noop", [True, False]) + def test_quantize_dequantize_noop( + self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool + ) -> None: + noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") + if noop: + noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") + dims = 23 + scale: float = 3.5 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor) + if noop_tensor.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) + def test_basic_ops( self, dims: DimsType = 23, diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 507fd3f350..cec25803f2 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -360,6 +360,20 @@ def test_fp16_exp_avg(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_bf16_exp_avg(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.bfloat16, + exp_avg_sq_dtype=torch.float32, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): @@ -389,6 +403,20 @@ def test_fp16_exp_avg_sq(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_bf16_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.bfloat16, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9c1a842cd8..c2b32ca272 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Iterable +import io import math from typing import Optional @@ -1393,6 +1394,7 @@ def test_make_extra_output( @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("cache_quantized_input", (False, True)) def test_activation( self, *, @@ -1401,6 +1403,7 @@ def test_activation( dtype: torch.dtype, device: torch.device = "cuda", quantization: Optional[str], + cache_quantized_input: bool, ) -> None: """Activation functions""" @@ -1412,6 +1415,8 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) + if cache_quantized_input: + maybe_skip_quantization("fp8", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1420,15 +1425,17 @@ def test_activation( test_device=device, test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_is_fp8=quantized_compute, requires_grad=False, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1459,7 +1466,8 @@ def test_activation( swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( - make_op(), + te_ops.Quantize(forward=False, backward=quantized_compute), + make_op(cache_quantized_input=cache_quantized_input), te_ops.Quantize(forward=quantized_compute, backward=False), ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): @@ -1468,9 +1476,9 @@ def test_activation( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute: + if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu": + if activation == "relu" and not cache_quantized_input: tols = {"atol": 0, "rtol": 0} # Check results @@ -1882,3 +1890,118 @@ def test_backward_linear_add( torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +class TestCheckpointing: + """Tests for checkpointing""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear( + self, + *, + pre_checkpoint_steps: int = 2, + post_checkpoint_steps: int = 2, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Check checkpointing with linear op""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + + # Construct model + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_save = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25) + + # Warmup training steps + for _ in range(pre_checkpoint_steps): + x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn(out_shape, dtype=dtype, device=device) + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(x) + y.backward(dy) + optim_save.step() + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save( + {"model": model_save.state_dict(), "optim": optim_save.state_dict()}, + byte_stream, + ) + checkpoint_bytes = byte_stream.getvalue() + del byte_stream + + # Synthetic data for evaluation + xs_save = [ + torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + for _ in range(post_checkpoint_steps) + ] + with torch.no_grad(): + xs_load = [x.clone().requires_grad_() for x in xs_save] + dys = [ + torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps) + ] + + # Training steps with original model + ys_save = [] + for i in range(post_checkpoint_steps): + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(xs_save[i]) + y.backward(dys[i]) + optim_save.step() + ys_save.append(y) + + # Load checkpoint + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_load = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) + state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) + model_load.load_state_dict(state_dict["model"]) + optim_load.load_state_dict(state_dict["optim"]) + + # Training steps with loaded model + ys_load = [] + for i in range(post_checkpoint_steps): + optim_load.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_load(xs_load[i]) + y.backward(dys[i]) + optim_load.step() + ys_load.append(y) + + # Check that original and loaded model match exactly + tols = {"rtol": 0, "atol": 0} + for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): + torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close(param_load.grad, param_save.grad, **tols) + for y_load, y_save in zip(ys_load, ys_save): + torch.testing.assert_close(y_load, y_save, **tols) + for x_load, x_save in zip(xs_load, xs_save): + torch.testing.assert_close(x_load.grad, x_save.grad, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..905339f4d3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,6 +50,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) sm_80plus = get_device_compute_capability() >= (8, 0) @@ -104,6 +107,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] @@ -563,6 +567,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -675,6 +681,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1028,7 +1036,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_granular_accuracy(block, bs, dtype, config): +def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False): reset_rng_states() inp_hidden_states = torch.randn( @@ -1044,12 +1052,18 @@ def _test_granular_accuracy(block, bs, dtype, config): out = out[0] loss = out.sum() loss.backward() + if delay_wgrad_compute: + block.backward_dw() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] for p in block.parameters(): if p.requires_grad: - outputs.append(p.grad) + if getattr(p, "main_grad", None) is not None: + outputs.append(p.main_grad) + assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True + else: + outputs.append(p.grad) return outputs @@ -1183,6 +1197,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation): + config = model_configs[model] + + te_linear_ref = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + te_linear = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + # Share params + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if bias: + te_linear_ref.bias = Parameter(te_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy( + te_linear_ref, bs, dtype, config, delay_wgrad_compute=False + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @@ -1368,6 +1430,67 @@ def test_layernorm_linear_accuracy( assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("normalization", all_normalizations) +@pytest.mark.parametrize("zero_centered_gamma", all_boolean) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +def test_layernorm_linear_accuracy_delay_wgrad_compute( + dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation +): + config = model_configs[model] + + ln_linear_ref = LayerNormLinear( + config.hidden_size, + 4 * config.hidden_size, + config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + delay_wgrad_compute=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + ln_linear = LayerNormLinear( + config.hidden_size, + 4 * config.hidden_size, + config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + delay_wgrad_compute=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + # Share params + with torch.no_grad(): + ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) + ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) + if bias: + ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(ln_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + ln_linear_ref.weight.main_grad = weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy( + ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @@ -1444,8 +1567,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("activation", all_activations) +@pytest.mark.parametrize("normalization", all_normalizations) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +def test_layernorm_mlp_accuracy_delay_wgrad_compute( + dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation +): + config = model_configs[model] + + ln_mlp = LayerNormMLP( + hidden_size=config.hidden_size, + ffn_hidden_size=4 * config.hidden_size, + eps=config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + ln_mlp_ref = LayerNormMLP( + hidden_size=config.hidden_size, + ffn_hidden_size=4 * config.hidden_size, + eps=config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + # Share params + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + if fuse_wgrad_accumulation: + ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32) + ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() + ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32) + ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy( + ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + def _test_grouped_linear_accuracy( - block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + block, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute=False, ): reset_rng_states() if fp8: @@ -1462,8 +1655,7 @@ def _test_grouped_linear_accuracy( if num_gemms > 1: split_size = 1 if fp8: - if recipe.delayed(): - split_size = 16 + split_size = 16 if recipe.mxfp8(): split_size = 128 m = config.seq_len // split_size @@ -1488,6 +1680,12 @@ def _test_grouped_linear_accuracy( ) loss = out.sum() loss.backward() + if delay_wgrad_compute: + if isinstance(block, GroupedLinear): + block.backward_dw() + else: + for i in range(num_gemms): + block[i].backward_dw() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] @@ -1501,33 +1699,34 @@ def _test_grouped_linear_accuracy( return outputs -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy( dtype, num_gemms, bs, model, - fp8, recipe, fp8_model_params, fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, parallel_mode=None, ): + fp8 = recipe is not None if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: + if fp8 and recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches - pytest.skip("MXFP8 unsupported for grouped linear.") - if fp8 and recipe.float8_current_scaling(): - pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1538,18 +1737,19 @@ def test_grouped_linear_accuracy( num_gemms, config.hidden_size, 4 * config.hidden_size, - bias=True, + bias=bias, params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, ).eval() sequential_linear = torch.nn.ModuleList( [ Linear( config.hidden_size, 4 * config.hidden_size, - bias=True, + bias=bias, params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", @@ -1563,17 +1763,34 @@ def test_grouped_linear_accuracy( with torch.no_grad(): for i in range(num_gemms): sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, ) outputs = _test_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, ) # Shoule be bit-wise match @@ -1581,24 +1798,7 @@ def test_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("parallel_mode", ["column", "row"]) -@pytest.mark.parametrize("recipe", fp8_recipes) -def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): - """Split the tests to save CI time""" - test_grouped_linear_accuracy( - dtype=torch.float32, - num_gemms=6, - bs=2, - model="126m", - fp8=True, - recipe=recipe, - fp8_model_params=True, - parallel_mode=parallel_mode, - fuse_wgrad_accumulation=True, - ) - - -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -1606,19 +1806,23 @@ def test_grouped_linear_accuracy_single_gemm(recipe): num_gemms=1, bs=2, model="126m", - fp8=True, recipe=recipe, fp8_model_params=True, fuse_wgrad_accumulation=True, + bias=True, + delay_wgrad_compute=False, ) def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): - """Padding tensor shapes to multiples of 16.""" + align_size = 16 + if recipe.mxfp8(): + align_size = 32 padded_tokens_per_expert = [ - (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert + (num_tokens + align_size - 1) // align_size * align_size + for num_tokens in tokens_per_expert ] hidden_states = torch.split(hidden_states, tokens_per_expert) padded_hidden_states = [] @@ -1719,10 +1923,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches - pytest.skip("MXFP8 unsupported for grouped linear.") - if fp8 and recipe.float8_current_scaling(): - pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1933,6 +2135,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 0dc183e298..07b5f7c529 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -8,6 +8,7 @@ import pytest from typing import Dict, List +from transformer_engine.common import recipe from transformer_engine.pytorch import ( moe_permute as te_permute, moe_permute_with_probs as te_permute_with_probs, @@ -17,9 +18,14 @@ ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex - +import copy seed = 1234 torch.manual_seed(seed) @@ -234,7 +240,6 @@ def _test_permutation_index_map( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -242,48 +247,12 @@ def _test_permutation_index_map( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - permute_bwd_input = torch.rand( - size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _permute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) - permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) - unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -323,9 +292,9 @@ def _test_permutation_index_map( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + te_permute_bwd_input = pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( te_permute_fwd_input, indices, num_out_tokens, map_type="index" @@ -338,7 +307,7 @@ def _test_permutation_index_map( te_probs.requires_grad_(True) te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_unpermute_output = te_unpermute( te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" @@ -352,16 +321,10 @@ def _test_permutation_index_map( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_permute_output_ = te_permute_output.float() - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() torch.testing.assert_close( pytorch_permute_output.float(), @@ -487,7 +450,6 @@ def _test_permutation_mask_map( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -495,49 +457,12 @@ def _test_permutation_mask_map( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - permute_bwd_input = torch.rand( - size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _permute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) - permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) - unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -553,10 +478,7 @@ def _test_permutation_mask_map( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums - if fp8: - probs = probs.to(torch.float16) - else: - probs = probs.to(dtype) + probs = probs.to(dtype) probs.requires_grad_(True) ################################################################################################################################### @@ -582,9 +504,9 @@ def _test_permutation_mask_map( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + te_permute_bwd_input = pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" @@ -597,7 +519,7 @@ def _test_permutation_mask_map( te_probs.requires_grad_(True) te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_unpermute_output = te_unpermute( te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" @@ -611,16 +533,10 @@ def _test_permutation_mask_map( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_permute_output_ = te_permute_output.float() - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() torch.testing.assert_close( pytorch_permute_output.float(), @@ -730,6 +646,118 @@ def _test_permutation_mask_map( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_mask_map_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + recipe, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + if recipe.delayed(): + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + elif recipe.float8_current_scaling(): + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=te_dtype, + device=torch.device("cuda"), + columnwise=False, + ) + elif recipe.float8_block_scaling(): + quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=0.0, + force_pow_2_scales=True, # Fp8 sub-channel a2a requires e8 scales + block_scaling_dim=1, # 1x128 scaling + ) + elif recipe.mxfp8(): + quantizer = MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + ) + else: + raise ValueError("Unsupported FP8 recipe") + + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + # Make an empty fp8 tensor + permute_fwd_input_fp8 = quantizer.make_empty( + permute_fwd_input.shape, + dtype=permute_fwd_input.dtype, + device=permute_fwd_input.device, + ) + # quantize the tensor + quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8) + if recipe.float8_block_scaling(): + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data) + pytorch_permute_fwd_scale_input = copy.deepcopy( + permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous() + ) + elif recipe.mxfp8(): + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data) + pytorch_permute_fwd_scale_input = copy.deepcopy( + permute_fwd_input_fp8._rowwise_scale_inv.contiguous() + ) + else: + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data) + pytorch_permute_fwd_scale_input = None + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + # PyTorch Permutaion + pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) + if pytorch_permute_fwd_scale_input is not None: + pytorch_permute_scale_output, _ = pytorch_permute_mask_map( + pytorch_permute_fwd_scale_input, routing_map + ) + + # TE Permutation + permute_output, _ = te_permute( + permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + ) + if recipe.float8_block_scaling(): + te_permute_output = permute_output._rowwise_data + te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous() + elif recipe.mxfp8(): + te_permute_output = permute_output._rowwise_data + te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous() + else: + te_permute_output = permute_output._data + te_permute_scale_output = None + + # check the permute output + torch.testing.assert_close( + pytorch_permute_output, + te_permute_output, + atol=0, + rtol=0, + ) + if recipe.float8_block_scaling() or recipe.mxfp8(): + torch.testing.assert_close( + pytorch_permute_scale_output, + te_permute_scale_output, + atol=0, + rtol=0, + ) + + def _test_moe_chunk_sort( te_dtype, num_tokens, @@ -743,7 +771,6 @@ def _test_moe_chunk_sort( f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -751,34 +778,11 @@ def _test_moe_chunk_sort( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - - _fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - fwd_input = _fwd_input_quantizer.quantize(fwd_input) - bwd_input = _bwd_input_quantizer.quantize(bwd_input) - - pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16) - pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_fwd_input.requires_grad_(True) @@ -806,9 +810,9 @@ def _test_moe_chunk_sort( # TE Permutation # ################################################################################################################################### - te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach() + te_fwd_input = pytorch_fwd_input.detach() te_fwd_input.requires_grad_(True) - te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach() + te_bwd_input = pytorch_bwd_input.detach() te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) te_output.backward(te_bwd_input, retain_graph=True) @@ -820,12 +824,8 @@ def _test_moe_chunk_sort( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_output_ = te_output.dequantize(dtype=torch.float32) - te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_output_ = te_output.float() - te_fwd_input_grad = te_fwd_input.grad.float() + te_output_ = te_output.float() + te_fwd_input_grad = te_fwd_input.grad.float() torch.testing.assert_close( pytorch_output.float(), @@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -907,38 +906,11 @@ def _test_permutation_mask_map_alongside_probs( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input) - unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -952,10 +924,7 @@ def _test_permutation_mask_map_alongside_probs( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums - if fp8: - probs = probs.to(torch.float16) - else: - probs = probs.to(dtype) + probs = probs.to(dtype) probs.requires_grad_(True) split_sizes = [0] * (num_expert * tp_size) @@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_probs = probs.detach() te_probs.requires_grad_(True) - print(te_probs.shape) te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( te_permute_fwd_input, @@ -1020,27 +988,14 @@ def _test_permutation_mask_map_alongside_probs( routing_map, num_out_tokens=num_out_tokens, ) - print(te_permuted_probs.shape) te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda ) - if fp8: - _permute_output_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - te_permute_output = te_permute_output.dequantize(dtype=torch.float32) - te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) - te_permute_output = _permute_output_quantizer.quantize(te_permute_output) - else: - te_permute_output_dtype = te_permute_output.dtype - print(te_permute_output.shape) - print(te_permuted_probs.shape) - te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) - te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) + te_permute_output_dtype = te_permute_output.dtype + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) te_permute_output = te_sort_chunks_by_index( te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda @@ -1058,13 +1013,8 @@ def _test_permutation_mask_map_alongside_probs( tols = dtype_tols(te_dtype) - if fp8: - # backward of dequantize is in high precision - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - else: - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() torch.testing.assert_close( pytorch_unpermute_output.float(), @@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), +] @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -1237,36 +1197,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation_index_map_fp8( - te_dtype, - num_tokens, - num_expert, - hidden_size, - topK, - num_out_tokens, -): - with_probs = True - BENCHMARK = False - - _test_permutation_index_map( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, - ) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("recipe", fp8_recipes) def test_permutation_mask_map_fp8( te_dtype, num_tokens, @@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8( hidden_size, topK, num_out_tokens, + recipe, ): - with_probs = True - BENCHMARK = False - - _test_permutation_mask_map( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, - ) - + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) -def test_permutation_mask_map_alongside_probs_fp8( - te_dtype, - num_tokens, - num_expert, - hidden_size, - topK, - num_out_tokens, - tp_size, -): - _test_permutation_mask_map_alongside_probs( + _test_permutation_mask_map_fp8( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, hidden_size=hidden_size, topK=topK, num_out_tokens=num_out_tokens, - tp_size=tp_size, + recipe=recipe, ) @@ -1415,11 +1320,9 @@ def test_permutation_single_case(): # te_dtype = tex.DType.kFloat32 # te_dtype = tex.DType.kFloat16 - # te_dtype = tex.DType.kBFloat16 - te_dtype = tex.DType.kFloat8E5M2 - # te_dtype = tex.DType.kFloat8E4M3 + te_dtype = tex.DType.kBFloat16 - num_tokens = 10 + num_tokens = 12 num_expert = 4 hidden_size = 16 topK = 2 diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 69ac8f7996..afb17d388a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -42,10 +42,14 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.utils import replace_raw_data +from transformer_engine.pytorch.distributed import checkpoint from test_numerics import reset_rng_states, dtype_tols # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -106,6 +110,7 @@ def is_fp8_supported(self): None, # Test non-FP8 recipe.MXFP8BlockScaling(), # Test default recipe.Float8CurrentScaling(), # Test default + recipe.Float8BlockScaling(), # Test default recipe.DelayedScaling(), # Test default recipe.DelayedScaling( # Test most_recent algo amax_history_len=16, @@ -439,6 +444,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -470,6 +477,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -502,6 +511,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -543,10 +554,10 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) - if fp8_recipe.mxfp8(): - pytest.skip("Grouped linear does not support MXFP8") - if fp8_recipe.float8_current_scaling(): - pytest.skip("Grouped linear does not support FP8 current scaling") + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -590,6 +601,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -640,6 +653,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -707,6 +722,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -766,6 +783,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -823,6 +842,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -858,6 +879,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -896,6 +919,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -937,6 +962,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -979,8 +1006,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling(): + pytest.skip("cuda graph not supported for float8_block_scaling recipe") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -1255,3 +1286,31 @@ def test_fp8_model_init_high_precision_init_val(): assert not hasattr( weight, "._high_precision_init_val" ), "clear_high_precision_init_val() not work" + + +def test_sanity_checkpointing_on_callables(): + """Test that TE checkpointing works correctly on callable modules.""" + + # torch.autograf.function + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + return inp + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + module = MyFunction.apply + inp = torch.randn(10, 10, device="cuda", requires_grad=True) + + out_checkpoint = checkpoint(module, inp) + out_checkpoint.sum().backward() + grad_checkpoint = inp.grad + + out_standard = module(inp) + out_standard.sum().backward() + grad_standard = inp.grad + + # Assert that gradients are the same + torch.testing.assert_close(grad_checkpoint, grad_standard) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 708403f911..67f173a4ab 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..daed7718ff 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -78,8 +78,8 @@ struct SimpleTensor { SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} operator NVTEBasicTensor() const { - const NVTEShape shape = {this->shape.data(), this->shape.size()}; - return {dptr, static_cast(dtype), shape}; + return {dptr, static_cast(dtype), + nvte_make_shape(this->shape.data(), this->shape.size())}; } int numel() const { @@ -99,11 +99,6 @@ struct Tensor { SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; - private: - // Used as an allocation for nvte_tensor_shape - // if the shape has to be inferred from columnwise data. - mutable std::vector rowwise_shape_cache; - public: NVTEScalingMode scaling_mode; @@ -194,11 +189,6 @@ struct Tensor { } } - const std::vector &rowwise_shape_ref() const { - rowwise_shape_cache = shape(); - return rowwise_shape_cache; - } - /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted @@ -233,10 +223,12 @@ struct Tensor { struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; + NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float) // amax_epsilon + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6fe3539257..0cd0762ee5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const int arch = cuda::sm_arch(); // Transpose mode with column-major ordering - bool transa_bool = transA == CUBLAS_OP_T; - bool transb_bool = transB == CUBLAS_OP_T; + bool is_A_transposed = transA == CUBLAS_OP_T; + bool is_B_transposed = transB == CUBLAS_OP_T; // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = transA; ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - ret.lda = transa_bool ? k : m; - if (arch < 100 && !transa_bool) { + ret.lda = is_A_transposed ? k : m; + if (arch < 100 && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = transA; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.lda = m; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? k : m; } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage @@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transB = transB; ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - ret.ldb = transb_bool ? n : k; - if (arch < 100 && transb_bool) { + ret.ldb = is_B_transposed ? n : k; + if (arch < 100 && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = transB; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; - ret.ldb = k; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = is_B_transposed ? n : k; } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = CUBLAS_OP_N; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; // Requirements from @@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7fa7957fa4..64136b2c43 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,7 +89,7 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. @@ -102,6 +102,16 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); +/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output quantized tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h index 9b0b80acc2..9c194e9da2 100644 --- a/transformer_engine/common/include/transformer_engine/normalization.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable); +/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma + * in weight dtype. If set to false, it will compute in compute dtype. + * + * Currently this only applies to the CuDNN backend. If CuDNN is not used, + * this setting has no effect. + * + * \param[in] bool Enable if True + */ +void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable); + enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; #ifdef __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ba47b9d38c..2c3192f773 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -42,6 +42,8 @@ struct NVTEShape { const size_t *data; /*! \brief Number of dimensions. */ size_t ndim; + /*! \brief Copy of data. Num dims limited to permit fixed struct size.*/ + size_t owned_data[14]; }; /*! \struct NVTEBasicTensor @@ -134,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor); */ void *nvte_tensor_columnwise_data(const NVTETensor tensor); +/*! \brief Construct a shape from an array of dimension sizes. + * + * \param[data] Pointer to start of shape array. + * \param[data] Number of dimensions (must be <= 14) + * + * \return A shape. The shape will own its own copy of the data. + */ +NVTEShape nvte_make_shape(const size_t *data, size_t ndim); + /*! \brief Get a tensor's data shape. * * \param[in] tensor Tensor. @@ -286,6 +297,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, + /*! Noop tensor (containing a scalar). + If the scalar element value = 1, quantization kernel will early exit. + This is a tensor because the flag must be on GPU in order to enable + conditional early even when captured in a static CUDA graph. + */ + kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; @@ -411,8 +428,9 @@ class TensorWrapper { float *amax_dptr = nullptr, float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, - scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + : TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr, + scale_dptr, scale_inv_dptr, + nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()), scaling_mode) {} /*! \brief Constructs new empty TensorWrapper. @@ -528,7 +546,9 @@ class TensorWrapper { * \return Shape of this TensorWrapper. */ const NVTEShape shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_shape(tensor_); } @@ -537,7 +557,9 @@ class TensorWrapper { * \return Shape of this TensorWrapper. */ const NVTEShape columnwise_shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_columnwise_shape(tensor_); } @@ -650,7 +672,9 @@ class TensorWrapper { * \return scale_inv_shape of this TensorWrapper. */ const NVTEShape scale_inv_shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_scale_inv_shape(tensor_); } @@ -666,12 +690,20 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = {&defaultData, 1}; + static constexpr NVTEShape defaultShape = { + &defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; private: - NVTEShape convertShape(const NVTEShape &s) { return s; } + NVTEShape convertShape(const NVTEShape &s) { + NVTEShape ret = s; + // Move the ownership rather than pointing to the parent shape. + ret.data = ret.owned_data; + return ret; + } - NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; @@ -724,6 +756,12 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } + /*! \brief Set noop tensor pointer */ + void set_noop_tensor(NVTETensor noop_tensor) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor, + sizeof(NVTETensor)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index ddda78d951..89affc081c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -39,6 +39,8 @@ Compute always in FP32 namespace transformer_engine { namespace normalization { +bool& use_zero_centered_gamma_in_weight_dtype(); + cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { return training ? cudnn_frontend::NormFwdPhase_t::TRAINING : cudnn_frontend::NormFwdPhase_t::INFERENCE; @@ -207,9 +209,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor _ndim_scale_block = 1; } - _scalar_dptr = std::make_unique(typeToSize(wtype)); + const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype; + + _scalar_dptr = std::make_unique(typeToSize(gamma_dtype)); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); + gamma_dtype, cpp_dtype, + *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); _handle = cudnnExecutionPlanManager::Instance().GetHandle(); @@ -239,13 +244,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_name("one") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(wtype)) + .set_data_type(get_cudnn_fe_dtype(gamma_dtype)) .set_is_pass_by_value(true)); auto centered_options = fe::graph::Pointwise_attributes() .set_mode(fe::PointwiseMode_t::ADD) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); - _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); + _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(gamma_dtype)); } else { _gamma = _gamma_zero; } @@ -503,6 +508,13 @@ bool& _cudnn_norm_bwd_flag() { bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } +bool& _zero_centered_gamma_in_weight_dtype() { + static bool flag = transformer_engine::getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE"); + return flag; +} + +bool& use_zero_centered_gamma_in_weight_dtype() { return _zero_centered_gamma_in_weight_dtype(); } + } // namespace normalization } // namespace transformer_engine @@ -515,3 +527,8 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; } + +void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { + NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); + transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; +} diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index f6b6ae22c2..47b37b3482 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -31,19 +31,24 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); + NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape."); + NVTE_CHECK(gamma.data.dtype == beta.data.dtype, + "Gamma and Beta must have the same dtype. Gamma dtype: " + + to_string(gamma.data.dtype) + ", Beta dtype: " + to_string(beta.data.dtype)); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size."); - NVTE_CHECK(epsilon >= 0.f); + NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative."); - NVTE_CHECK(z->data.shape == x.data.shape); + NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x."); - NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(mu->data.dtype == DType::kFloat32); + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}, + "Mu must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor."); - NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}, + "RSigma must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); if (!workspace->data.shape.empty()) { CheckInputTensor(x, "x"); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index c56f9ef407..48cf1d819b 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -27,15 +27,16 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } - NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); - NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); - NVTE_CHECK(epsilon >= 0.f); + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1], "Gamma must have the same hidden size."); + NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative."); - NVTE_CHECK(z->data.shape == x.data.shape); + NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x."); - NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}, + "RSigma must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); if (!workspace->data.shape.empty()) { CheckInputTensor(x, "x"); diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 7e9e2a97f7..2ac38c93cf 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so const transformer_engine::Tensor *input_fwd_cu = reinterpret_cast(input_fwd); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), reinterpret_cast(output_cu->data.dptr), @@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id const transformer_engine::Tensor *prob_cu = reinterpret_cast(prob); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), reinterpret_cast(output_cu->data.dptr), diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b676bf6ab0..80857e565c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -5,6 +5,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations import warnings +import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -81,6 +82,10 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + def float8_block_scaling(self): + """Whether the given recipe is float8 blockwise scaling.""" + return isinstance(self, Float8BlockScaling) + @dataclass() class DelayedScaling(Recipe): @@ -287,3 +292,99 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + + +@dataclass() +class Float8BlockScaling(Recipe): + """ + Use block-wise scaling for FP8 tensors. + + In this strategy, tensors are scaled in blockwise fashion. Values within + each block share a common scaling factor. The block dimensionality + can be configured. The scaling factors are float32 containers. They + will by default be constrained to powers of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), the quantized tensor and its transpose are not numerically + equivalent. Due to this, when Transformer Engine needs both the FP8 tensor + and its transpose (e.g. to calculate both forward and backward pass), + during the quantization both versions are computed from the high precision + input to avoid double quantization errors. + + NOTE: To relax the default constraint that scales be powers of 2, set env variable + NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. + export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code for increased control. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of gradient tensor dY + x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for x. + w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for w. + grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for grad. + fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + """ + + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" + + fp8_format: Format = Format.E4M3 + fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + x_block_scaling_dim: int = 1 + w_block_scaling_dim: int = 2 + grad_block_scaling_dim: int = 1 + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" + assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" + assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" + assert not ( + self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." + assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." + assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"x_block_scaling_dim={self.x_block_scaling_dim}, " + f"w_block_scaling_dim={self.w_block_scaling_dim}, " + f"grad_block_scaling_dim={self.grad_block_scaling_dim}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 97df5892b6..9072e1d060 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -211,6 +211,22 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { reinterpret_cast(tensor)->dtype()); } +NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { + NVTEShape ret; + if (ndim == 0) { + ret.data = nullptr; + ret.ndim = 0; + return ret; + } + NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), + "Too many dims for NVTEShape (requested: ", ndim, + ", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")"); + std::copy(data, data + ndim, ret.owned_data); + ret.data = ret.owned_data; + ret.ndim = ndim; + return ret; +} + NVTEShape nvte_tensor_shape(const NVTETensor tensor) { if (tensor == nullptr) { NVTE_ERROR("Invalid tensor"); @@ -218,12 +234,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); - const std::vector &rowwise_shape = t.rowwise_shape_ref(); + std::vector shape = t.shape(); - NVTEShape ret; - ret.data = rowwise_shape.data(); - ret.ndim = rowwise_shape.size(); - return ret; + return nvte_make_shape(shape.data(), shape.size()); } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { @@ -231,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTE_ERROR("Invalid tensor"); } const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - return ret; + return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size()); } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } @@ -258,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { size_t nvte_tensor_element_size(const NVTETensor tensor) { if (tensor == nullptr) return sizeof(float); const auto &t = *reinterpret_cast(tensor); - return transformer_engine::typeToSize(t.data.dtype); + return transformer_engine::typeToSize(t.dtype()); } void *nvte_tensor_data(const NVTETensor tensor) { @@ -302,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { } NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; + if (tensor == nullptr) { + return nvte_make_shape(nullptr, 0); + } const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - ret.data = t.scale_inv.shape.data(); - ret.ndim = t.scale_inv.shape.size(); - return ret; + return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size()); } void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, @@ -429,6 +438,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(buf, &config_.amax_epsilon, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(buf, &config_.noop_tensor, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -458,6 +470,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(&config_.amax_epsilon, buf, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(&config_.noop_tensor, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 298d087337..3148b4f720 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -29,11 +29,35 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream); +// enum class for rowwise usage +enum class FP8BlockwiseRowwiseOption { + // No rowwise data + NONE, + // Rowwise data, scales in GEMM format + ROWWISE + // TODO: FP8 all gather requires some changes. + // 1. Compact scales are better for gathering than the GEMM format. +}; + +// enum class for columnwise usage +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class FP8BlockwiseColumnwiseOption { + // No columnwise data + NONE, + // Columnwise data transposed from original shape. + // Scales in GEMM format corresponding to GEMM ingesting transposed column data. + COLUMNWISE_TRANSPOSE + // TODO: FP8 all gather requires some changes. + // 1. The transpose gets in the way of the all gather. + // 2. Compact scales are better for gathering than the GEMM format. +}; + void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 732d97999c..91f73dea1e 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -16,11 +16,15 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" #include "common/utils.cuh" namespace transformer_engine { namespace { +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + // clang-format off /* @@ -138,15 +142,17 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); template -__global__ void __launch_bounds__(kThreadsPerBlock) - block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, - OType* const output_t, CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, const size_t row_length, - const size_t num_rows, const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon, - bool return_transpose, bool pow_2_scaling) { +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scaling) { + bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; + bool return_columnwise_transpose = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; + using SMemVec = Vec; using OVec = Vec; union IVec { @@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast and store to output_c - { + if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if (return_transpose) { + if (return_columnwise_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -389,10 +395,15 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, - cudaStream_t stream) { + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + // assert that rowwise_option and columnwise_option are not both NONE + NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || + columnwise_option != FP8BlockwiseColumnwiseOption::NONE, + "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -408,21 +419,24 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. - - NVTE_CHECK(input.shape.size() == output.shape.size(), - "Input and output must have the same shape."); - size_t scale_stride_x = 0; size_t scale_stride_y = 0; - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; - size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (return_transpose) { + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + "Unexpected rowwise enum value"); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + } + + if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, + "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -469,10 +483,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, - pow2_scale);) // kAligned - ) // OutputType - ) // InputType + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 22a50025df..1f146c7a33 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no NVTE_API_CALL(nvte_quantize_noop); using namespace transformer_engine; + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = false; constexpr bool IS_ACT = false; @@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, noop, output, - dbias, workspace, stream); + detail::quantize_helper( + input, grad, output, dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ -66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +92,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +148,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c6a8b0f23c..a599d88530 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,9 +1215,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe auto output_tensor = reinterpret_cast(output); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); + + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); switch (output_tensor->scaling_mode) { @@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, + output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - constexpr bool force_pow_2_scales = true; - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() + ? FP8BlockwiseRowwiseOption::ROWWISE + : FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE + : FP8BlockwiseColumnwiseOption::NONE; + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/debug/__init__.py b/transformer_engine/debug/__init__.py new file mode 100644 index 0000000000..62f7f41728 --- /dev/null +++ b/transformer_engine/debug/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Top level package for numerical debugging.""" + +try: + from . import pytorch + from .pytorch.debug_state import set_weight_tensor_tp_group_reduce +except ImportError as e: + pass diff --git a/transformer_engine/debug/pytorch/__init__.py b/transformer_engine/debug/pytorch/__init__.py new file mode 100644 index 0000000000..8bdbe287de --- /dev/null +++ b/transformer_engine/debug/pytorch/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py new file mode 100644 index 0000000000..4a7a156a0a --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -0,0 +1,528 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +This file contains DebugQuantizer and DebugQuantizedTensor objects, +which are wrappers over Quantizer and QuantizedTensor. +These wrappers add logic related to debugging, using the nvdlfw_inspect package. +""" + +from __future__ import annotations +from typing import Optional, Tuple, Iterable, Union +import torch + +import transformer_engine_torch as tex + + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +aten = torch.ops.aten + +_tensor_to_gemm_names_map = { + "weight": ["fprop", "dgrad"], + "activation": ["fprop", "wgrad"], + "output": ["fprop", None], + "gradient": ["dgrad", "wgrad"], + "wgrad": ["wgrad", None], + "dgrad": ["dgrad", None], +} + +API_CALL_MODIFY = "modify_tensor()" +STANDARD_FP8_QUANTIZE = "FP8 Quantize" +HIGH_PRECISION = "High Precision" + + +class DebugQuantizer(Quantizer): + """ + DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect. + It allows adding custom calls inside the quantization process - which enables modifying tensors + or gathering tensor stats. + """ + + def __init__( + self, + layer_name: str, + tensor_name: str, + parent_quantizer: Optional[Quantizer], + tp_group: torch.distributed.ProcessGroup, + ): + import nvdlfw_inspect.api as debug_api + + super().__init__(rowwise=True, columnwise=True) + self.layer_name = layer_name + self.tensor_name = tensor_name + self.parent_quantizer = parent_quantizer + self.tp_group = tp_group # used in inspect_tensor calls + self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, + # rowwise_tensor_plan, and columnwise_tensor_plan are computed. + # These fields indicate the path where API calls will be inserted. + # + # inspect_tensor*_enabled are bool fields, + # indicating whether some feature will need to run inspect_tensor_* calls. + # + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # determining what will happen when the quantizer is used for that tensor. + self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] + if self.output_tensor: + self.inspect_tensor_enabled, self.rowwise_tensor_plan = ( + self.get_plans_for_output_tensors() + ) + else: + ( + self.inspect_tensor_enabled, + self.inspect_tensor_postquantize_enabled_rowwise, + self.inspect_tensor_postquantize_enabled_columnwise, + ) = self.get_enabled_look_at_tensors() + self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan() + + self.log_messages_about_plans() + + def get_plans_for_output_tensors(self) -> Tuple[bool, str]: + """ + Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the + API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support + gemm output in FP8. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION + + return inspect_tensor_enabled, plan + + def get_enabled_look_at_tensors(self): + """ + Returns a tuple of booleans determining which functions look_at_tensor_*(...) should be called. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + inspect_tensor_postquantize_enabled_rowwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.rowwise_gemm_name, + ) + ) + inspect_tensor_postquantize_enabled_columnwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.columnwise_gemm_name, + ) + ) + + return ( + inspect_tensor_enabled, + inspect_tensor_postquantize_enabled_rowwise, + inspect_tensor_postquantize_enabled_columnwise, + ) + + def get_tensors_plan(self): + """ + Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of + API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + of this quantizer with respect to these tensors. + """ + import nvdlfw_inspect.api as debug_api + + rowwise_plan = None + columnwise_plan = None + + modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_rowwise: + rowwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + rowwise_plan = STANDARD_FP8_QUANTIZE + if rowwise_plan is None: + rowwise_plan = HIGH_PRECISION + + if self.columnwise_gemm_name is not None: + modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_columnwise: + columnwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + columnwise_plan = STANDARD_FP8_QUANTIZE + if columnwise_plan is None: + columnwise_plan = HIGH_PRECISION + + return rowwise_plan, columnwise_plan + + def log_messages_about_plans(self): + """ + Logs the messages about the plans for each of the tensors. + """ + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -" + f" {self.rowwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name), + ) + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -" + f" {self.columnwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name), + ) + + def _call_inspect_tensor_api( + self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None + ): + import nvdlfw_inspect.api as debug_api + + args = { + "layer_name": self.layer_name, + "tensor": tensor, + "tensor_name": self.tensor_name, + "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "tp_group": self.tp_group, + } + if tensor is not None and self.inspect_tensor_enabled: + debug_api.transformer_engine.inspect_tensor(**args) + + if self.output_tensor: + return + + if ( + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_rowwise + ): + args["tensor"] = rowwise_gemm_tensor + args["rowwise"] = True + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_columnwise + ): + args["tensor"] = columnwise_gemm_tensor + args["rowwise"] = False + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None, + dtype: torch.dtype = None, + ): + """Returns DebugQuantizedTensor object.""" + import nvdlfw_inspect.api as debug_api + + assert not self.output_tensor + if out is not None: + return self.update_quantized(tensor, self) + + # 1. If there is fp8 quantization in at least one of the gemms, + # the quantization using the self.parent_quantizer is performed. + + # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + if columnwise_gemm_quantize and not rowwise_gemm_quantize: + rowwise_gemm_quantize = True # only columnwise quantization not implemented + + rowwise_gemm_tensor, columnwise_gemm_tensor = None, None + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=True, + columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported + ) + quantized_tensor = self.parent_quantizer(tensor) + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # one tensor with columnwise=True and rowwise=True is computed + # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. + if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + rowwise_gemm_tensor = quantized_tensor + if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + columnwise_gemm_tensor = quantized_tensor + + # 2. modify_tensor() is called, if it is used. + if self.columnwise_tensor_plan == API_CALL_MODIFY: + columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + if self.rowwise_tensor_plan == API_CALL_MODIFY: + rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + + # 3. If some tensors still are not defined we use high precision tensor. + if self.rowwise_tensor_plan == HIGH_PRECISION: + rowwise_gemm_tensor = tensor.to(dtype) + if self.columnwise_tensor_plan == HIGH_PRECISION: + columnwise_gemm_tensor = tensor.to(dtype) + + self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor) + + # sometimes we may want to return simple tensor with only rowwise_gemm + if self.tensor_name in ["wgrad", "dgrad", "output"]: + return rowwise_gemm_tensor + + return DebugQuantizedTensor( + rowwise_gemm_tensor=rowwise_gemm_tensor, + columnwise_gemm_tensor=columnwise_gemm_tensor, + quantizer=self, + layer_name=self.layer_name, + tensor_name=self.tensor_name, + ) + + def process_gemm_output(self, tensor: torch.Tensor): + """This call is invoked after the gemm to inspect and modify the output tensor.""" + import nvdlfw_inspect.api as debug_api + + assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.output_tensor + tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} + if self.rowwise_tensor_plan == API_CALL_MODIFY: + tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + gemm=tensor_to_gemm[self.tensor_name], + tensor_name=self.tensor_name, + tensor=tensor, + iteration=self.iteration, + default_quantizer=self.parent_quantizer, + ) + self._call_inspect_tensor_api(tensor) + return tensor + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Override make_empty() from Quantizer class.""" + if self.parent_quantizer is not None: + return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) + return torch.empty(shape, dtype=dtype, device=device) + + def calibrate(self, tensor: torch.Tensor): + """Calibration override, should not be invoked.""" + raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Update quantized tensor - used in weight caching.""" + import nvdlfw_inspect.api as debug_api + + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + + updated_rowwise_gemm = False + if self.parent_quantizer is not None: + if ( + dst.rowwise_gemm_tensor is not None + and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ): + if hasattr(dst.rowwise_gemm_tensor, "quantize_"): + dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None) + updated_rowwise_gemm = True + if ( + dst.columnwise_gemm_tensor is not None + and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and not updated_rowwise_gemm + ): + if hasattr(dst.columnwise_gemm_tensor, "quantize_"): + dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None) + + if self.columnwise_tensor_plan == API_CALL_MODIFY: + out = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.columnwise_gemm_tensor, + iteration=self.iteration, + ) + assert out is None, ( + "API call debug_api.transformer_engine.modify_tensor with out != None should" + " return None" + ) + if self.rowwise_tensor_plan == API_CALL_MODIFY: + debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.rowwise_gemm_tensor, + iteration=self.iteration, + ) + + if self.rowwise_tensor_plan == HIGH_PRECISION: + dst.rowwise_gemm_tensor.copy_(src) + if self.columnwise_tensor_plan == HIGH_PRECISION: + # if they are the same tensor object, it is sufficient to update one + if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor: + dst.columnwise_gemm_tensor.copy_(src) + + self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) + + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + + +class DebugQuantizedTensor: + """ + Class containing quantized tensors after debug. Depending on configuration + it can contain one or two different objects. These objects can be accessed by the method + get_tensor(). + """ + + def __init__( + self, + rowwise_gemm_tensor, + columnwise_gemm_tensor, + quantizer, + layer_name=None, + tensor_name=None, + ): + + self.rowwise_gemm_tensor = rowwise_gemm_tensor + self.columnwise_gemm_tensor = columnwise_gemm_tensor + self.quantizer = quantizer + self._layer_name = layer_name + self._tensor_name = tensor_name + + def prepare_for_saving(self): + """ " Prepare for saving method override""" + self.tensors_to_save = ( + [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor] + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor + else [self.rowwise_gemm_tensor] + ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) + self.tensors_to_save = tensor_objects_list + # pylint: disable=unbalanced-tuple-unpacking + return tensor_list, self + + def restore_from_saved(self, tensors): + """Restore from saved method override""" + tensor_objects_list, saved_tensors = restore_from_saved( + self.tensors_to_save, + tensors, + return_saved_tensors=True, + ) + if len(tensor_objects_list) == 2: + # pylint: disable=unbalanced-tuple-unpacking + self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list + else: + self.rowwise_gemm_tensor = tensor_objects_list[0] + self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors + + def quantize_(self, tensor, *, noop_flag=None): + """ " quantize_ method override""" + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + self.quantizer.update_quantized(tensor, self) + + def dequantize(self, *, dtype=None): + """ " dequantize method override""" + if dtype is None: + dtype = self.rowwise_gemm_tensor.dtype + return self.rowwise_gemm_tensor.dequantize().to(dtype) + + def get_tensor(self, transpose: bool): + """Is used in the python gemm() to get tensor or transpose of the tensor.""" + return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor + + def size(self): + """Size of the tensor.""" + return self.rowwise_gemm_tensor.size() + + def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + """Update usage of the tensor.""" diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py new file mode 100644 index 0000000000..11edb3641f --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Managing the state of all the debugged layers. +""" + +import sys + + +class TEDebugState: + """ + A class to manage the state of debug layers. + """ + + layer_count = 1 + layers_initialized = {} + weight_tensor_tp_group_reduce = True + debug_enabled = None + + @classmethod + def initialize(cls): + """ + If debug_api module is initialized, then sets cls.debug_enabled to True. + """ + + if "nvdlfw_inspect" in sys.modules: + import nvdlfw_inspect.api as debug_api + + if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None: + # This method is invoked when initializing TE modules. + # If this error is thrown, it means that some TE module had been initialized before + # debug_api was initialized, and now a new TE module is being initialized. + # This is likely to be a bug. + raise RuntimeError( + "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before" + " initialization of the first TE module" + ) + cls.debug_enabled = debug_api.DEBUG_MANAGER is not None + + @classmethod + def _reset(cls): + """Resets layer count and stats buffers.""" + from ..features.utils.stats_buffer import STATS_BUFFERS + + STATS_BUFFERS.reset() + cls.debug_enabled = None + cls.layers_initialized.clear() + + @classmethod + def get_layer_count(cls): + """ + Layer counter is used when layer names are not provided to modules by the user. + """ + lc = cls.layer_count + cls.layer_count += 1 + return lc + + @classmethod + def set_weight_tensor_tp_group_reduce(cls, enabled): + """Sets weight tensor reduction mode.""" + cls.weight_tensor_tp_group_reduce = enabled + + +def set_weight_tensor_tp_group_reduce(enabled): + """Sets weight tensor reduction mode.""" + TEDebugState.set_weight_tensor_tp_group_reduce(enabled) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py new file mode 100644 index 0000000000..4aea05333c --- /dev/null +++ b/transformer_engine/debug/pytorch/utils.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utils functions for the debug module.""" + + +def any_feature_enabled(quantizers): + """Returns True if at least one API call is made from DebugQuantizer.""" + return any(q.any_feature_enabled() for q in quantizers) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7676781c3..21d5503e3e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -162,7 +163,7 @@ def lowering( assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ) return out @@ -282,7 +283,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -293,9 +294,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -339,7 +340,7 @@ def partition( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -350,9 +351,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -391,7 +392,7 @@ def sharded_impl(x, scale): ) ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -406,6 +407,54 @@ def sharded_impl(x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + act_enum, + act_len, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types + + x_rank = len(value_types[0].shape) + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + x_rank - 1, unique_var="i", flatten_axis=-2 + ) + x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) + out = (*x_axes[:-2], x_axes[-1]) + scale_inv = scale_rules.rowwise_rule + colwise_scale_inv = scale_rules.colwise_rule + + if is_2x: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple( + multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) + ) + else: + colwise_out = out + else: + colwise_out = ("j",) + colwise_scale_inv = ("k",) + + # amax is always a unit tensor. + amax = ("l",) + + return SdyShardingRule( + ( + x_axes, + "…1", + ), + (out, colwise_out, scale_inv, colwise_scale_inv, amax), + **scale_rules.factor_sizes, + ) + register_primitive(ActLuPrimitive) @@ -463,7 +512,7 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -545,7 +594,7 @@ def lowering( dz, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), @@ -673,7 +722,7 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -691,9 +740,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -743,7 +792,7 @@ def partition( ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -761,9 +810,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -810,7 +859,7 @@ def sharded_impl(dz, x, scale): else: global_dbias = local_dbias - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -819,6 +868,46 @@ def sharded_impl(dz, x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_dbias, + act_enum, + act_len, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types + + x_rank = len(value_types[1].shape) + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + x_rank, unique_var="i", flatten_axis=-2 + ) + x_axes = scale_rules.input_spec + out = x_axes + if is_2x: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) + else: + colwise_out = tuple(x_axes) + else: + colwise_out = ("j",) + + dbias = x_axes[-2:] if is_dbias else ("k",) + amax = ("…4",) + + return SdyShardingRule( + (("…0",), tuple(x_axes), ("…2",)), + (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + **scale_rules.factor_sizes, + ) + register_primitive(DActLuDBiasQuantizePrimitive) @@ -928,7 +1017,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1131,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused @@ -1095,7 +1184,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 7a31fa729d..ea682d4c47 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -14,6 +14,7 @@ import jax.numpy as jnp from jax import dtypes, lax from jax.sharding import PartitionSpec, NamedSharding +from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -42,6 +43,7 @@ get_mesh_axis_rank, get_all_mesh_axes, num_of_devices, + with_sharding_constraint, ) @@ -618,6 +620,35 @@ def partition(config, mesh, arg_infos, result_infos): impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + del mesh, result_types + + # Keep in sync with `infer_sharding_from_operands`. + # We only need the first input. Fill up the rest with placeholders. + input_spec = [(f"…{x}",) for x in range(len(value_types))] + # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint + # instead. This has to happen outside of the primitive, see `fused_attn_fwd`. + rng_sharding = (f"…{len(value_types)}",) + + if config.qkv_layout.is_qkvpacked(): + input_spec[0] = ("…0", "seqlen", "three", "head", "hidden") + elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate(): + input_spec[0] = ("…0", "seqlen", "head", "hidden") + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") + + is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() + out_sharding = ("…0", "seqlen", "head", "hidden") + if is_packed_softmax: + softmax_aux_sharding = ("…0", "seqlen", "head", "i") + else: + softmax_aux_sharding = ("…0", "head", "seqlen", "i") + + return SdyShardingRule( + tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding) + ) + register_primitive(FusedAttnFwdPrimitive) @@ -998,6 +1029,15 @@ def sharded_impl( return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + del config, mesh + # We only care about the four first arguments. + # Keep in sync with `infer_sharding_from_operands`. + input_spec = tuple((f"…{x}",) for x in range(len(value_types))) + output_spec = tuple((f"…{x}",) for x in range(len(result_types))) + return SdyShardingRule(input_spec, output_spec) + register_primitive(FusedAttnBwdPrimitive) @@ -2436,13 +2476,15 @@ def fused_attn_fwd( primitive = FusedRingAttnFwdPrimitive.outer_primitive seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) - return primitive.bind( + output, softmax_aux, rng_state = primitive.bind( *qkv_for_primitive, bias, seed, *seq_desc_flatten, config=fused_config, ) + rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None)) + return (output, softmax_aux, rng_state) def fused_attn_bwd( diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 5d64fa9bb6..1c9bade0e7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -98,6 +98,15 @@ def partition(): """ return NotImplemented + @staticmethod + @abstractmethod + def shardy_sharding_rule(*args): + """ + Returns the sharding rule for this primitive. + """ + del args + return "... -> ..." + def register_primitive(cls): """ @@ -123,7 +132,9 @@ def name_of_wrapper_p(): batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition + infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition, + sharding_rule=cls.shardy_sharding_rule, ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1df2bcc97f..588e7a469d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9) + impl_static_args = () inner_primitive = None outer_primitive = None @staticmethod - def abstract( - lhs_contig_aval, - lhs_scale_contig_aval, - rhs_contig_aval, - rhs_scale_contig_aval, - bias_contig_aval, - dim_list_aval, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ): - del lhs_contig_aval, lhs_scale_contig_aval - del rhs_contig_aval, rhs_scale_contig_aval - del bias_contig_aval, dim_list_aval - del num_gemms, scaling_mode - out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype) - wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8) - return (out_flat_aval, wkspace_aval) + def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + """ + Args: + *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: + args[ 0 : num_gemms] are the lhs tensors, + args[ num_gemms : 2*num_gemms] are the rhs tensors, + args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, + args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, + args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. + num_gemms: Number of GEMM operations to perform. + scaling_mode: Scaling mode for the GEMM operations. + out_dtype: Data type of the output tensors. + has_bias: Boolean indicating if bias tensors are provided. + + Returns: + A tuple of ShapedArray objects of size num_gemms+1: + ret[0 : num_gemms]: GEMM output tensors, + ret[num_gemms]:workspace tensor. + """ + del scaling_mode + expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms + assert ( + len(args) == expected_num_args + ), f"Expected {expected_num_args} input arguments, but got {len(args)}" + A_list = args[0:num_gemms] + B_list = args[num_gemms : 2 * num_gemms] + # A and B have shapes [1, m, k] and [1, n, k] + out_list_aval = tuple( + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + for A, B in zip(A_list, B_list) + ) + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + return (*out_list_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): @@ -74,60 +87,27 @@ def outer_abstract(*args, **kwargs): return out_aval @staticmethod - def lowering( - ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: - del out_dtype, out_flat_size + def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, scaling_mode=int(scaling_mode), + has_bias=has_bias, ) @staticmethod - def impl( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: + def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): assert GroupedGemmPrimitive.inner_primitive is not None out = GroupedGemmPrimitive.inner_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, scaling_mode=scaling_mode.value, out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=has_bias, ) - return out[0] # out is [out_flat, wkspace], only return out_flat + return out[:-1] # out is [out_list, wkspace], only return out_list register_primitive(GroupedGemmPrimitive) @@ -198,7 +178,7 @@ def _jax_gemm_delayed_scaling_fp8( ): """FP8 GEMM for XLA pattern match""" assert ( - rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -230,7 +210,7 @@ def _jax_gemm_mxfp8_1d( JAX GEMM for MXFP8 via scaled_matmul """ assert ( - rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING + rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper @@ -291,10 +271,10 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) - if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") @@ -366,6 +346,7 @@ def swizzled_scale(scales): rows, cols = scales.shape scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) + scales = scales.reshape(rows, cols) return scales @@ -380,18 +361,12 @@ def grouped_gemm( len(lhs_list) == len(rhs_list) == len(contracting_dims_list) ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - # Flatten inputs and save their shapes - num_gemms = len(lhs_list) - out_flat_size = 0 - dims = [] - lhs_contig_ = [] - rhs_contig_ = [] - lhs_scale_inv_contig_ = [] - rhs_scale_inv_contig_ = [] - bias_contig_ = [] - out_offsets = [] - remain_shape_list = [] num_gemms = len(lhs_list) + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] for i in range(num_gemms): lhs = lhs_list[i] rhs = rhs_list[i] @@ -402,8 +377,8 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -415,7 +390,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -427,24 +402,25 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + # swizzled_scale requires a matrix lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) else: raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - # Note: if _shape_normalization() is updated to support non-TN, need to update here - # already_transposed doesn't matter for the output shape + # Note: already_transposed doesn't matter for the output shape # x.shape = [B, D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] @@ -455,66 +431,37 @@ def grouped_gemm( bn = rhs_remain_shape[0] kl = lhs_3d.shape[-1] kr = rhs_3d.shape[-1] - remain_shape_list.append(((bm,), (bn,))) - assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" - k = kl - - if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): - print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print( - f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" - " of 16" - ) - assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 - - dims.append((bm, bn, k)) - lhs_contig_.append(lhs_3d.reshape(-1)) - rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) + assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" + if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): + print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print(f"m = {bm}, n = {bn}, k = {kl}; ") + print("cuBLAS requires the problem shapes being multiples of 16") + assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) + + lhs_list_.append(lhs_3d) + rhs_list_.append(rhs_3d) + if scaling_mode == ScalingMode.NO_SCALING: + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) if bias_list is not None: - bias_contig_.append(bias_list[i].reshape(-1)) - out_flat_size += bm * bn - out_offsets.append(out_flat_size) - - lhs_contig = jnp.concatenate(lhs_contig_) - rhs_contig = jnp.concatenate(rhs_contig_) - lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_) - rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_) - bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_) - dim_list = jnp.array(dims, dtype=jnp.int32) - - # TE/common does not support NVTE_NO_SCALING yet - # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING - - # Perform batched GEMM on flattened inputs - out_contig = GroupedGemmPrimitive.outer_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + bias_list_.append(bias_list[i]) + + out_list = GroupedGemmPrimitive.outer_primitive.bind( + *lhs_list_, + *rhs_list_, + *lhs_sinv_list_, + *rhs_sinv_list_, + *bias_list_, num_gemms=num_gemms, scaling_mode=scaling_mode, out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=1 if bias_list is not None else 0, ) - # Split the output back into tensors - out_offsets = jnp.array(out_offsets) - out_flat_list = jnp.split(out_contig, out_offsets[:-1]) - out_tensors = [] - for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list): - out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) - - return out_tensors + return out_list diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index c79eda5568..d64104ac27 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 74882c92db..54360c2dcc 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -12,6 +12,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -63,6 +64,27 @@ def get_backward_sm_margin(): return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) +@cache +def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool: + """Retrieves whether CuDNN norm fwd is enabled.""" + # MXFP8_1D_SCALING always uses CuDNN currently + return ( + int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1 + or scaling_mode == ScalingMode.MXFP8_1D_SCALING + ) + + +@cache +def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bool: + """Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma + in weight dtype as opposed to compute dtype.""" + if not is_norm_fwd_cudnn_enabled(scaling_mode): + # If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype + # Remove this when TE supports gamma += 1.0 in weight dtype + return False + return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1 + + class NormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward FP8 Primitive @@ -105,6 +127,26 @@ def abstract( if norm_type == NVTE_Norm_Type.LayerNorm: assert gamma_aval.size == beta_aval.size + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + if norm_type == NVTE_Norm_Type.RMSNorm: + mu_aval = mu_aval.update(shape=(1,)) + + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -112,33 +154,13 @@ def abstract( jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(out_dtype), norm_type, - scaling_mode.value, + scaling_mode, zero_centered_gamma, epsilon, get_forward_sm_margin(), is_2x, ) - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_aval = mu_aval.update(shape=(1,)) - - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x_aval.shape, is_padded=not is_outer - ) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) - colwise_out_aval = jax.core.ShapedArray( - shape=x_aval.shape if is_2x else (1,), dtype=out_dtype - ) - - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - - wkspace_aval = x_aval.update( + wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) @@ -274,9 +296,9 @@ def impl( scale_shapes=scale_shapes, is_outer=False, ) - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x.shape, is_padded=False - ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) # slice out padding for mxfp8, noop for DelayedScaling scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape @@ -364,6 +386,8 @@ def infer_sharding_from_operands( del zero_centered_gamma, epsilon, out_dtype, result_infos del scale_dtype, scale_shapes, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) + out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -371,34 +395,27 @@ def infer_sharding_from_operands( "and hurt performance." ) - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") + + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") output = ( out_sharding, colwise_out_sharding, @@ -427,8 +444,11 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) g_spec = get_padded_spec(arg_infos[2]) b_spec = get_padded_spec(arg_infos[3]) + out_spec = (*x_spec[:-1], None) + if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -445,43 +465,30 @@ def partition( f"{NormFwdPrimitive.name} does not support sharding of parameter beta " "Enforcing no sharding of parameters hidden dim! " ) - x_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x" - ) - g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma") - b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta") - out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out") - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" + ) rsigma_sharding = NamedSharding( - mesh, - PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]), - desc="NormFwdPrimitive.rsigma", + mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale" - ) - scale_inv_sharding = scale_sharding.duplicate_with_new_description( - "NormFwdPrimitive.scale_inv" + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -517,7 +524,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -534,6 +541,57 @@ def sharded_impl(x, scale, gamma, beta): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + norm_type, + zero_centered_gamma, + epsilon, + out_dtype, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_outer, + mesh, + value_types, + result_types, + ): + del ( + zero_centered_gamma, + epsilon, + out_dtype, + scale_dtype, + scale_shapes, + is_outer, + mesh, + result_types, + ) + + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + len(value_types[0].shape), unique_var="i", flatten_axis=-1 + ) + x_axes = scale_rules.input_spec + + out = x_axes[:-1] + ("k",) + colwise_out = out if is_2x else ("…4",) + rsigma = x_axes[:-1] + mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = ("…6",) + + return SdyShardingRule( + (x_axes, ("…1",), ("…2",), ("…3",)), + ( + out, + colwise_out, + scale_rules.rowwise_rule, + scale_rules.colwise_rule, + amax, + mu, + rsigma, + ), + **scale_rules.factor_sizes, + ) + register_primitive(NormFwdPrimitive) @@ -737,6 +795,11 @@ def sharded_impl(dz, x, mu, rsigma, gamma): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(*args): + del args + return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l" + register_primitive(NormBwdPrimitive) @@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) JAX native layernorm implementation """ x_ = jnp.asarray(x, jnp.float32) + if not is_norm_zero_centered_gamma_in_weight_dtype( + quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING + ): + gamma = gamma.astype(jnp.float32) mean = jnp.mean(x_, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) @@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): JAX native rmsnorm implementation """ x_ = jnp.asarray(x, jnp.float32) + if not is_norm_zero_centered_gamma_in_weight_dtype( + quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING + ): + gamma = gamma.astype(jnp.float32) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) normed_input = x_ * rsigma @@ -824,7 +895,6 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) - if quantizer is None: output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, @@ -835,7 +905,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -845,7 +915,7 @@ def layernorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -864,7 +934,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -873,7 +943,7 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -882,7 +952,7 @@ def layernorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) @@ -1017,7 +1087,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1027,7 +1097,7 @@ def rmsnorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -1046,7 +1116,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -1055,7 +1125,7 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1064,7 +1134,7 @@ def rmsnorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 034e149c50..23d8572994 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -93,7 +94,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -114,6 +115,10 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + scaling_mode, + QuantizeLayout( + q_layout + ), # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -176,7 +181,7 @@ def lowering( ctx, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, @@ -302,7 +307,7 @@ def infer_sharding_from_operands( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -322,9 +327,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -374,7 +379,7 @@ def partition( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -394,9 +399,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -445,7 +450,7 @@ def sharded_impl(x, scale): is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -466,6 +471,48 @@ def sharded_impl(x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + scale_shapes, + is_dbias, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types + + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis + ) + + x_axes = scale_rules.input_spec + colwise_scale_inv = scale_rules.colwise_rule + + out = x_axes + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) + else: + colwise_out = x_axes + else: + colwise_out = ("j",) + colwise_scale_inv = ("k",) + + dbias = x_axes[flatten_axis:] if is_dbias else ("l",) + amax = ("m",) + + return SdyShardingRule( + (x_axes, ("…1",)), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, + ) + register_primitive(DBiasQuantizePrimitive) @@ -588,7 +635,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index b50e98081d..1556fa3344 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -31,6 +31,9 @@ "scaled_upper_triang_masked_softmax_fwd", "scaled_upper_triang_masked_softmax_bwd", "is_softmax_kernel_available", + "jax_scaled_softmax", + "jax_scaled_masked_softmax", + "jax_scaled_upper_triang_masked_softmax", ] @@ -330,6 +333,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "... -> ..." + register_primitive(ScaledSoftmaxFwdPrimitive) @@ -400,6 +408,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledSoftmaxBwdPrimitive) @@ -412,7 +425,7 @@ def scaled_softmax_bwd( Return FP16/BF16 tensor """ if not ScaledSoftmaxBwdPrimitive.enabled(): - _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits) + _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits) return vjp_func(dz)[0] return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( @@ -525,6 +538,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "...1, ...2 -> ...1" + register_primitive(ScaledMaskedSoftmaxFwdPrimitive) @@ -596,6 +614,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledMaskedSoftmaxBwdPrimitive) @@ -682,6 +705,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): result_infos, ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "... -> ..." + register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) @@ -761,15 +789,26 @@ def partition(scale_factor, mesh, arg_infos, result_infos): result_infos, ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) -def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): +def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled softmax + """ return jax.nn.softmax(scale_factor * logits) -def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): +def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled and masked softmax + """ if mask is not None: logits += jax.lax.select( mask > 0, @@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac return jax.nn.softmax(logits * scale_factor) -def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): +def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled and upper triangle masked softmax + """ mask = 1 - jnp.tril(jnp.ones_like(logits)) logits += jax.lax.select( mask > 0, @@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: Return FP16/BF16 tensor """ if not ScaledSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_softmax(logits, scale_factor) + return jax_scaled_softmax(logits, scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) @@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd( Return FP16/BF16 tensor """ if not ScaledMaskedSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_masked_softmax(logits, mask, scale_factor) + return jax_scaled_masked_softmax(logits, mask, scale_factor) return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, mask, scale_factor=scale_factor ) @@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd( """ if not ScaledMaskedSoftmaxBwdPrimitive.enabled(): _, vjp_func = jax.vjp( - partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask + partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask ) return vjp_func(dz)[0] return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( @@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl Return FP16/BF16 tensor """ if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor) + return jax_scaled_upper_triang_masked_softmax(logits, scale_factor) return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, scale_factor=scale_factor ) @@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd( """ if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled(): _, vjp_func = jax.vjp( - partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits + partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits ) return vjp_func(dz)[0] return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -31,6 +31,9 @@ #include "transformer_engine/activation.h" #include "utils.h" +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax { @@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); + +pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, bool is_2x); + // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); @@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training); @@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); - -pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x); + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e71597e4b3..fc7f231f34 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -17,7 +17,7 @@ namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, + Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis @@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("act_enum") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto dact_input_tensor = TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); - auto output_tensor = TensorWrapper(static_cast(scaling_mode)); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter @@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } } - if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, - bool is_dbias, int64_t act_enum) { + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); - auto scaling_mode = static_cast(scaling_mode_enum); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK( - !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { @@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") + .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias") - .Attr("act_enum"), + .Attr("is_dbias"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e5ec160c91..4318e19c75 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,34 +15,9 @@ namespace transformer_engine { namespace jax { -constexpr static size_t MXFP8_BLOCK_SIZE = 32; - -// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX) -Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr, - const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype, - uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, - const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, - uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const int64_t &scaling_mode, - cudaStream_t stream) { - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms; - std::unique_ptr dim_list_host = std::make_unique(3 * num_gemms); - - cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - +Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -56,6 +31,18 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + if (num_gemms <= 0) { + return ffi_with_cuda_error_check(); + } + size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; + size_t expected_output_size = num_gemms + 1; + size_t actual_input_size = input_list.size(); + size_t actual_output_size = output_list.size(); + NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", + expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", + expected_output_size, actual_output_size); + bool trans_lhs = true; bool trans_rhs = false; auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh std::vector out_list; std::vector workspace_list; + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; + int lhs_sinv_list_offset = 2 * num_gemms; + int rhs_sinv_list_offset = 3 * num_gemms; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; for (int i = 0; i < num_gemms; i++) { - size_t m = dim_list_host[i * 3]; - size_t n = dim_list_host[i * 3 + 1]; - size_t k = dim_list_host[i * 3 + 2]; + Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); + Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); + + DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); + DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); + DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); + + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); + void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); + void *out_ptr = out_i->untyped_data(); + + // Placeholder for bias since it can be empty + DType bias_dtype = DType::kFloat32; + void *bias_ptr = nullptr; + + auto lhs_shape_ = lhs_i.dimensions(); + auto rhs_shape_ = rhs_i.dimensions(); + + // lhs and rhs has shape [1, m, k] and [1, n, k] + size_t m = lhs_shape_[1]; + size_t n = rhs_shape_[1]; + size_t k = lhs_shape_[2]; auto lhs_shape = std::vector{m, k}; auto rhs_shape = std::vector{n, k}; @@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, - nullptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, - nullptr, reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t sinv_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_shape[0] = m; - lhs_sinv_shape[1] = sinv_k; - rhs_sinv_shape[0] = n; - rhs_sinv_shape[1] = sinv_k; - + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + float *amax_dptr = nullptr; + float *scale_dptr = nullptr; + auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(lhs_sinv_ptr)); + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(rhs_sinv_ptr)); + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, - lhs_sinv_shape); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, - rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); + auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); + auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); + for (int i = 0; i < 2; i++) { + lhs_sinv_shape[i] = lhs_sinv_shape_[i]; + rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + } + + NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + TensorWrapper lhs_i_(nvte_scaling_mode); + TensorWrapper rhs_i_(nvte_scaling_mode); + lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); + rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); + lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); + rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); - lhs_ptr += m * k * lhs_dtype_bytes; - rhs_ptr += n * k * rhs_dtype_bytes; - out_ptr += m * n * out_dtype_bytes; - lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes; - + auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); void *pre_gelu_ptr = nullptr; auto bias_shape = std::vector{0}; auto pre_gelu_shape = std::vector{0}; - if (bias_ptr != nullptr) bias_shape[0] = n; + if (has_bias) { + auto bias_i_get = input_list.get(bias_list_offset + i); + Buffer_Type bias_i = bias_i_get.value(); + bias_ptr = bias_i.untyped_data(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); + bias_shape[0] = n; + } auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes; auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); - out_wrapper_list.push_back(std::move(out_i)); + out_wrapper_list.push_back(std::move(out_i_)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh out_list.push_back(out_wrapper_list.back().data()); } + auto workspace_get = output_list.get(num_gemms); + Result_Type workspace = workspace_get.value(); + uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh return ffi_with_cuda_error_check(); } -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, - Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, - Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, - Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv_flatten.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv_flatten.untyped_data()); - auto bias_ptr = reinterpret_cast(bias_flatten.untyped_data()); - auto dim_list_ptr = reinterpret_cast(dim_list.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type()); - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type()); - - // Outputs - auto out_ptr = reinterpret_cast(out_flatten->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type()); - auto workspace_ptr = reinterpret_cast(workspace_flatten->untyped_data()); - auto workspace_size = workspace_flatten->dimensions().back() / num_streams; - - return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype, - rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype, - workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode, - stream); -} - XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_flatten - .Arg() // lhs_sinv_flatten - .Arg() // rhs_flatten - .Arg() // rhs_sinv_flatten - .Arg() // bias_flatten - .Arg() // dim_list - .Ret() // out_flatten - .Ret() // workspace_flatten + .RemainingArgs() // input list + .RemainingRets() // output list .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode") + .Attr("has_bias"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8526e20c0..f7577c24f3 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -40,5 +40,28 @@ enum class QuantizeLayout { ROWWISE_COLWISE, }; +enum class JAXX_Scaling_Mode : int64_t { + NO_SCALING = 0, + DELAYED_TENSOR_SCALING = 1, + MXFP8_1D_SCALING = 2, +}; + +static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { + switch (mode) { + case JAXX_Scaling_Mode::NO_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::MXFP8_1D_SCALING: + return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + break; + default: + NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 03855753cf..e23e42f528 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -14,7 +14,8 @@ namespace jax { pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - auto _scaling_mode = static_cast(scaling_mode); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated - if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { int temp = 1; output_tensor.set_columnwise_data(static_cast(&temp), out_dtype, input_shape); } @@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, @@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, int scaling_mode, bool is_2x) { + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); - auto _scaling_mode = static_cast(scaling_mode); auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); if (is_fp8_dtype(out_dtype)) { @@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc scale_inv_buf->dimensions().back()}); } - if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); @@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, @@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("zero_centered_gamma") .Attr("epsilon") .Attr("sm_margin") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ebdfe461c7..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("RMSNorm", NVTE_Norm_Type::RMSNorm) .export_values(); - pybind11::enum_(m, "NVTE_Scaling_Mode", pybind11::module_local()) - .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) - .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) - .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) + .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index b48ee8a9b9..481dbd7cdf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -13,7 +13,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ int temp = 0; auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); - auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); + + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + TensorWrapper dummy_workspace; nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), @@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type output_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, - int64_t flatten_axis) { + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -54,7 +83,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto scaling_mode = static_cast(scaling_mode_enum); auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); @@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T std::vector workspace_shape{workspace_dims.begin(), workspace_dims.end()}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis"), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a944848881..ef60052768 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -13,7 +13,6 @@ from flax import linen as nn from flax.linen import partitioning as nn_partitioning from jax import lax -from jax import nn as jax_nn from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name @@ -26,7 +25,12 @@ from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes -from ..cpp_extensions import is_softmax_kernel_available +from ..cpp_extensions import ( + is_softmax_kernel_available, + jax_scaled_softmax, + jax_scaled_masked_softmax, + jax_scaled_upper_triang_masked_softmax, +) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..sharding import get_non_contracting_logical_axes @@ -168,10 +172,10 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp input_dtype = inputs.dtype logits = inputs - if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( + # use primitives + if is_softmax_kernel_available( self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype ): - if bias is not None: logits = logits + bias.astype(input_dtype) @@ -180,31 +184,22 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp mask_ = None outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type) + # use default jax based implementation else: - attention_bias = None - if mask is not None: - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, -1e10), - jnp.full(mask.shape, 0.0), - ) - attention_bias = attention_bias.astype(input_dtype) - if bias is not None: - attention_bias = _combine_biases(attention_bias, bias) - - if attention_bias is not None: - logits = logits + attention_bias.astype(input_dtype) + logits = logits + bias.astype(input_dtype) - # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED - # and kernel is unavailable, then try on pure scaled softmax custom calls. - if is_softmax_kernel_available( - SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype - ): - outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) + if self.softmax_type is SoftmaxType.SCALED: + outputs = jax_scaled_softmax(logits, self.scale_factor) + elif self.softmax_type is SoftmaxType.SCALED_MASKED: + outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor) + elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: + outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor) else: - outputs = jax_nn.softmax(logits * self.scale_factor) - + raise ValueError( + f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED," + " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]" + ) assert input_dtype == outputs.dtype return outputs @@ -361,7 +356,7 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 70a4da9186..10a0a06824 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -220,11 +220,11 @@ def convert_to_softmax_type(attn_mask_type, mask): if mask is not None: mask = apply_swa_mask(mask) # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if mask is not None: + return SoftmaxType.SCALED_MASKED, mask + if attn_mask_type is AttnMaskType.CAUSAL_MASK: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: - if mask is not None: - return SoftmaxType.SCALED_MASKED, mask + if attn_mask_type is AttnMaskType.NO_MASK: return SoftmaxType.SCALED, mask raise ValueError( f"Unsupported {attn_mask_type=}, supported attn_mask_type=" @@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + attn_mask_type mask/sequence_descriptor SWA softmax type + -------------------------------------------------------------------------------------------- + no_mask None None SCALED + causal None None SCALED_UPPER_TRIANG_MASKED + causal None Yes SCALED_MASKED + padding Required Yes/No SCALED_MASKED + padding_causal Required Yes/No SCALED_MASKED + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. diff --git a/transformer_engine/jax/praxis/__init__.py b/transformer_engine/jax/praxis/__init__.py deleted file mode 100644 index 5352f1f53b..0000000000 --- a/transformer_engine/jax/praxis/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Praxis related Modules""" -from .module import FusedSoftmax, LayerNorm -from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer -from .transformer import DotProductAttention, MultiHeadAttention -from .transformer import RelativePositionBiases, TransformerLayer -from ..flax.transformer import TransformerLayerType diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py deleted file mode 100644 index ce407f94fc..0000000000 --- a/transformer_engine/jax/praxis/module.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" -Praxis Modules -""" -from dataclasses import field -from functools import partial -from typing import Callable, Iterable, Sequence, Tuple, Union - -from praxis import pax_fiddle -from praxis.base_layer import init_var -from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection -from praxis.layers import flax_adapter -from praxis.pytypes import JTensor - -from ..fp8 import FP8Helper -from ..flax.module import DenseGeneral, LayerNormDenseGeneral -from ..flax.module import LayerNorm as flax_LayerNorm -from ..flax.module import LayerNormMLP as flax_LayerNormMLP -from ..flax.module import Softmax -from ..softmax import SoftmaxType - - -def _generate_ln_scale_init(scale_init): - if scale_init is not None: - return TransformerEngineBaseLayer.generate_params_init("scale", scale_init) - return scale_init - - -class TransformerEngineBaseLayer(BaseLayer): - """TransformerEngineBaseLayer""" - - logical_axes_rules: Tuple[Tuple, ...] = None - - @staticmethod - def generate_params_init(name: str, initializer: WeightInit): - """generate_params_init""" - - def kernel_init(key, shape, dtype): - wp = WeightHParams(shape=shape, init=initializer, dtype=dtype) - return init_var(wp, key, name) - - return kernel_init - - def create_layer(self, name, flax_module_cls): - """create_layer""" - - fp8_collection_map = { - FP8Helper.FP8_COLLECTION_NAME: [ - WeightHParamsCollection.SKIP_LP_REGULARIZATION, - WeightHParamsCollection.OVERWRITE_WITH_GRADIENT, - WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION, - ] - } - - flax_module_p = pax_fiddle.Config( - flax_adapter.FlaxModuleAdapter, - module_factory_method=flax_module_cls, - logical_axes_rules=self.logical_axes_rules, - var_collection_map=fp8_collection_map, - ici_mesh_shape=self.ici_mesh_shape, - dcn_mesh_shape=self.dcn_mesh_shape, - mesh_axis_names=self.mesh_axis_names, - ) - - self.create_child(name, flax_module_p.clone()) - - -class LayerNorm(TransformerEngineBaseLayer): - """LayerNorm""" - - epsilon: float = 1e-6 - layernorm_type: str = "layernorm" - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - ln_cls = partial( - flax_LayerNorm, - epsilon=self.epsilon, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init), - bias_axes=self.bias_axes, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("layer_norm", ln_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.layer_norm(x) - - -class FusedSoftmax(TransformerEngineBaseLayer): - """FusedSoftmax""" - - scale_factor: float = 1.0 - softmax_type: SoftmaxType = SoftmaxType.SCALED - - def setup(self) -> None: - """setup""" - super().setup() - - fused_softmax_cls = partial( - Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type - ) - - self.create_layer("fused_softmax", fused_softmax_cls) - - def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor: - """__call__""" - return self.fused_softmax(x, mask, bias) - - -class Linear(TransformerEngineBaseLayer): - """Linear""" - - out_features: int = 512 - kernel_axes: Tuple[str, ...] = () - use_bias: bool = True - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - dense_general_cls = partial( - DenseGeneral, - features=self.out_features, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes=self.kernel_axes, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes=self.bias_axes, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("linear", dense_general_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.linear(x) - - -class LayerNormLinear(TransformerEngineBaseLayer): - """LayerNormLinear""" - - out_features: int = 512 - enable_layernorm: bool = True - layernorm_type: str = "layernorm" - epsilon: float = 1e-6 - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=1.0) - ) - ln_bias_axes: Tuple[str, ...] = () - kernel_axes: Tuple[str, ...] = () - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - return_layernorm_output: bool = True - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - depth_scaling: float = None - - def setup(self) -> None: - """setup""" - super().setup() - - ln_dense_general_cls = partial( - LayerNormDenseGeneral, - features=self.out_features, - enable_layernorm=self.enable_layernorm, - layernorm_type=self.layernorm_type, - epsilon=self.epsilon, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init - ), - ln_bias_axes=self.ln_bias_axes, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes=self.kernel_axes, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes=self.bias_axes, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - return_layernorm_output=self.return_layernorm_output, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - depth_scaling=self.depth_scaling, - ) - - self.create_layer("ln_linear", ln_dense_general_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.ln_linear(x) - - -class LayerNormMLP(TransformerEngineBaseLayer): - """LayerNormMLP""" - - intermediate_dim: int = 2048 - enable_layernorm: bool = True - layernorm_type: str = "layernorm" - epsilon: float = 1e-6 - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=1.0) - ) - ln_bias_axes: Tuple[str, ...] = () - kernel_axes_1: Tuple[str, ...] = () - kernel_axes_2: Tuple[str, ...] = () - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes_1: Tuple[str, ...] = () - bias_axes_2: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",) - intermediate_dropout_rate: float = 0.1 - intermediate_hidden_dropout_dims: Sequence[int] = () - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - ln_mlp_cls = partial( - flax_LayerNormMLP, - intermediate_dim=self.intermediate_dim, - enable_layernorm=self.enable_layernorm, - layernorm_type=self.layernorm_type, - epsilon=self.epsilon, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init - ), - ln_bias_axes=self.ln_bias_axes, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes_1=self.kernel_axes_1, - kernel_axes_2=self.kernel_axes_2, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes_1=self.bias_axes_1, - bias_axes_2=self.bias_axes_2, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - return_layernorm_output=self.return_layernorm_output, - activations=self.activations, - intermediate_dropout_rate=self.intermediate_dropout_rate, - intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("ln_mlp", ln_mlp_cls) - - def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor: - """__call__""" - return self.ln_mlp(x, deterministic) diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py deleted file mode 100644 index f441834355..0000000000 --- a/transformer_engine/jax/praxis/transformer.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" -Praxis Modules related Transformer -""" -from dataclasses import field -from functools import partial -from typing import Optional, Sequence, Tuple -import warnings - -from praxis import pax_fiddle -from praxis.base_layer import WeightInit -from praxis.pytypes import JTensor - -from .module import TransformerEngineBaseLayer -from ..flax.transformer import TransformerLayerType -from ..flax.transformer import DotProductAttention as flax_DotProductAttention -from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention -from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases -from ..flax.transformer import TransformerLayer as flax_TransformerLayer -from ..attention import AttnBiasType, AttnMaskType - - -class RelativePositionBiases(TransformerEngineBaseLayer): - """RelativePositionBiases""" - - num_buckets: int = 32 - max_distance: int = 128 - num_attention_heads: int = 64 - embedding_init: WeightInit = None - embedding_axes: Tuple[str, ...] = () - - @staticmethod - def generate_embedding_init(init, num_attention_heads, num_buckets): - """generate_embedding_init""" - embedding_init = init - if embedding_init is None: - rb_stddev = (num_attention_heads * num_buckets) ** -0.5 - embedding_init = WeightInit.Gaussian(rb_stddev) - return embedding_init - - def setup(self) -> None: - """setup""" - super().setup() - - embedding_init = RelativePositionBiases.generate_embedding_init( - self.embedding_init, self.num_attention_heads, self.num_buckets - ) - - rpb_cls = partial( - flax_RelativePositionBiases, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - num_attention_heads=self.num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init - ), - embedding_axes=self.embedding_axes, - dtype=self.dtype, - ) - - self.create_layer("relative_position_bias", rpb_cls) - - def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor: - """__call__""" - return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional) - - -class DotProductAttention(TransformerEngineBaseLayer): - """DotProductAttention""" - - head_dim: int = 0 - num_attention_heads: int = 0 - num_gqa_groups: Optional[int] = None - attention_dropout: float = 0.0 - attn_mask_type: AttnMaskType = "causal" - attn_bias_type: AttnBiasType = None - dropout_rng_name: str = "dropout" - float32_logits: bool = False - qkv_layout: str = "bshd_bshd_bshd" - scale_factor: Optional[float] = None - transpose_batch_sequence: bool = True - window_size: Optional[Tuple[int, int]] = None - - def setup(self) -> None: - """setup""" - super().setup() - - assert self.head_dim > 0, f"{self.head_dim=}" - assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" - - dpa_cls = partial( - flax_DotProductAttention, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - attention_dropout=self.attention_dropout, - dtype=self.dtype, - dropout_rng_name=self.dropout_rng_name, - float32_logits=self.float32_logits, - qkv_layout=self.qkv_layout, - scale_factor=self.scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence, - window_size=self.window_size, - ) - - self.create_layer("dot_product_attention", dpa_cls) - - def __call__( - self, - query: JTensor, - key: JTensor, - value: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - deterministic: bool = False, - ) -> JTensor: - """__call__""" - return self.dot_product_attention( - query, key, value, mask, bias, deterministic=deterministic - ) - - -class MultiHeadAttention(TransformerEngineBaseLayer): - """MultiHeadAttention""" - - head_dim: int = 0 - num_attention_heads: int = 0 - num_gqa_groups: Optional[int] = None - attention_dropout: float = 0.0 - dropout_rng_name: str = "dropout" - input_layernorm: bool = True - layernorm_type: str = "layernorm" - layernorm_epsilon: float = 1e-6 - zero_centered_gamma: bool = False - return_layernorm_output: bool = False - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - attn_mask_type: str = "causal" - attn_bias_type: Optional[str] = None - enable_rotary_pos_emb: bool = False - rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = "consecutive" - low_rank_adaptation_scope: str = "none" - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - fuse_qkv_params: bool = True - transpose_batch_sequence: bool = True - enable_sequence_parallel: bool = False - scale_attn_logits: bool = False - scaled_query_init: bool = True - float32_logits: bool = False - window_size: Optional[Tuple[int, int]] = None - - # Deprecated parameters - num_heads: Optional[int] = None - dropout_rate: Optional[float] = None - output_layernorm: Optional[bool] = None - apply_residual_connection_post_layernorm: Optional[bool] = None - fuse_qkv: Optional[bool] = None - - def __post_init__(self): - # Deal with the deprecated parameters - if self.num_heads is not None: - self.num_attention_heads = self.num_heads - warnings.warn( - f"{__class__}.num_heads is deprecated. It will be removed recently. " - f"Please uses {__class__}.num_attention_heads as the new API.", - DeprecationWarning, - ) - if self.dropout_rate is not None: - self.attention_dropout = self.dropout_rate - warnings.warn( - f"{__class__}.dropout_rate is deprecated. It will be removed recently. " - f"Please use {__class__}.attention_dropout as the new API.", - DeprecationWarning, - ) - if self.apply_residual_connection_post_layernorm is not None: - warnings.warn( - f"{__class__}.apply_residual_connection_post_layernorm is deprecated. " - f"It will be removed recently, please use {__class__}.return_layernorm_output.", - DeprecationWarning, - ) - if self.fuse_qkv is not None: - warnings.warn( - f"{__class__}.fuse_qkv is deprecated. It will be removed recently. " - f"Please use {__class__}.fuse_qkv_params as the new API.", - DeprecationWarning, - ) - assert self.output_layernorm is None, ( - f"{__class__}.output_layernorm is deprecated. It will be removed recently. " - f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm." - ) - - if self.num_gqa_groups is None: - self.num_gqa_groups = self.num_heads - super().__post_init__() - - def setup(self) -> None: - """setup""" - super().setup() - - assert self.head_dim > 0, f"{self.head_dim=}" - assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" - - mha_cls = partial( - flax_MultiHeadAttention, - dtype=self.dtype, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attention_dropout=self.attention_dropout, - dropout_rng_name=self.dropout_rng_name, - input_layernorm=self.input_layernorm, - layernorm_type=self.layernorm_type, - layernorm_epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - return_layernorm_output=self.return_layernorm_output, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_windows=self.rotary_pos_emb_windows, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - low_rank_adaptation_scope=self.low_rank_adaptation_scope, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - fuse_qkv_params=self.fuse_qkv_params, - transpose_batch_sequence=self.transpose_batch_sequence, - enable_sequence_parallel=self.enable_sequence_parallel, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - float32_logits=self.float32_logits, - window_size=self.window_size, - ) - - self.create_layer("multi_head_attn", mha_cls) - - def __call__( - self, - inputs_q: JTensor, - inputs_kv: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - decode: bool = False, - deterministic: bool = False, - ) -> JTensor: - """__call__""" - return self.multi_head_attn( - inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic - ) - - -class TransformerLayer(TransformerEngineBaseLayer): - """TransformerLayer""" - - hidden_size: int = 512 - mlp_hidden_size: int = 2048 - num_attention_heads: int = 8 - num_gqa_groups: Optional[int] = None - layernorm_type: str = "layernorm" - layernorm_epsilon: float = 1e-6 - zero_centered_gamma: bool = False - hidden_dropout: float = 0.1 - hidden_dropout_dims: Sequence[int] = () - attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 - intermediate_dropout_dims: Sequence[int] = () - dropout_rng_name: str = "dropout" - mlp_activations: Sequence[str] = ("relu",) - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - apply_residual_connection_post_layernorm: bool = False - output_layernorm: bool = False - float32_attention_logits: bool = False - layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = "causal" - self_attn_bias_type: Optional[str] = None - enable_rotary_pos_emb: bool = False - rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = "consecutive" - low_rank_adaptation_scope: str = "none" - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - enable_relative_embedding: bool = True - relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) - drop_path: float = 0.0 - fuse_qkv_params: bool = True - transpose_batch_sequence: bool = False - enable_sequence_parallel: bool = False - scale_attn_logits: bool = False - scaled_query_init: bool = True - window_size: Optional[Tuple[int, int]] = None - - def __post_init__(self): - if self.num_gqa_groups is None: - self.num_gqa_groups = self.num_attention_heads - super().__post_init__() - - def setup(self) -> None: - """setup""" - super().setup() - - relative_embedding_flax_module = None - if self.enable_relative_embedding and self.relative_embedding is not None: - assert self.relative_embedding.num_attention_heads == self.num_attention_heads, ( - "TransformerLayer.relative_embedding.num_attention_heads shoule be" - "the same as TransformerLayer.num_attention_heads." - ) - - embedding_init = RelativePositionBiases.generate_embedding_init( - self.relative_embedding.embedding_init, - self.relative_embedding.num_attention_heads, - self.relative_embedding.num_buckets, - ) - - relative_embedding_flax_module = flax_RelativePositionBiases( - num_buckets=self.relative_embedding.num_buckets, - max_distance=self.relative_embedding.max_distance, - num_attention_heads=self.relative_embedding.num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init - ), - embedding_axes=self.relative_embedding.embedding_axes, - dtype=self.relative_embedding.dtype, - ) - - transformerlayer_cls = partial( - flax_TransformerLayer, - dtype=self.dtype, - hidden_size=self.hidden_size, - mlp_hidden_size=self.mlp_hidden_size, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - layernorm_type=self.layernorm_type, - layernorm_epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - hidden_dropout=self.hidden_dropout, - hidden_dropout_dims=self.hidden_dropout_dims, - attention_dropout=self.attention_dropout, - intermediate_dropout=self.intermediate_dropout, - intermediate_dropout_dims=self.intermediate_dropout_dims, - dropout_rng_name=self.dropout_rng_name, - mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mha_kernel", self.params_init - ), - mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mlp_kernel", self.params_init - ), - mlp_activations=self.mlp_activations, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, - output_layernorm=self.output_layernorm, - float32_attention_logits=self.float32_attention_logits, - layer_type=self.layer_type, - self_attn_mask_type=self.self_attn_mask_type, - self_attn_bias_type=self.self_attn_bias_type, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_windows=self.rotary_pos_emb_windows, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - low_rank_adaptation_scope=self.low_rank_adaptation_scope, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - enable_relative_embedding=self.enable_relative_embedding, - relative_embedding=relative_embedding_flax_module, - drop_path=self.drop_path, - fuse_qkv_params=self.fuse_qkv_params, - transpose_batch_sequence=self.transpose_batch_sequence, - enable_sequence_parallel=self.enable_sequence_parallel, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - window_size=self.window_size, - ) - - self.create_layer("transformerlayer", transformerlayer_cls) - - def __call__( - self, - inputs: JTensor, - encoded: JTensor = None, - attention_mask: JTensor = None, - encoder_decoder_mask: JTensor = None, - deterministic: bool = False, - decode: bool = False, - max_decode_length: bool = None, - ) -> JTensor: - """__call__""" - return self.transformerlayer( - inputs, - encoded, - attention_mask, - encoder_decoder_mask, - deterministic, - decode, - max_decode_length, - ) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b1e9ba03b4..d68eb3c6c2 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -84,8 +84,8 @@ def _dq_func_block_scaling(scaled_tensor): ) funcs = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, + ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } @staticmethod diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d144aa69d..98f280b9a9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) return (False, "Unsupported scaling_mode!") def is_fp8_available( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, ) -> Tuple[bool, str]: """Check if FP8 is available for the given scaling mode and GPU. @@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ValueError: If the recipe type is not supported """ if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.NVTE_DELAYED_TENSOR_SCALING + return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.NVTE_MXFP8_1D_SCALING + return ScalingMode.MXFP8_1D_SCALING raise ValueError("Invalid fp8_recipe!") @@ -217,7 +217,7 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False IF_QUANTIZE_2X: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING + SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 @@ -253,11 +253,11 @@ def finalize(cls) -> None: cls.MARGIN = 0.0 cls.FP8_FORMAT = recipe.Format.HYBRID cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.IF_QUANTIZE_2X = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index bd7045453b..b57043a034 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -530,8 +530,8 @@ class QuantizerFactory: """ quantizer_type_map = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } @staticmethod @@ -556,8 +556,9 @@ def create( A single quantizer or tuple of quantizers """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted - # assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING - if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): + assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" + # import pdb; pdb.set_trace() + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] @@ -651,4 +652,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 95bbc9bb41..2a5b23bdcf 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -16,11 +16,33 @@ from functools import reduce import operator +from jax.experimental.custom_partitioning import CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp +from transformer_engine_jax import JAXX_Scaling_Mode -__all__ = ["ScalingMode"] + +__all__ = ["QuantizeShardyRules", "ScalingMode"] + + +@dataclass +class QuantizeShardyRules: + """Information necessary to shard scale tensors with Shardy. + + Attributes: + input_spec: Specification for the input axes + rowwise_rule: Sharding rule for the row-wise scale tensor, depends on + the axes in `input_spec` + colwise_rule: Likewise for the column-wise scale tensor. + factor_sizes: For block scaling, contains the block size factor, which is + used in `input_spec`. + """ + + input_spec: Tuple[str] + rowwise_rule: Tuple[str] + colwise_rule: Tuple[str] + factor_sizes: Dict[str, int] class ScalingModeMetadataImpl(ABC): @@ -57,6 +79,21 @@ def get_scale_shape( The shape for scale tensors """ + @abstractmethod + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for delayed scaling mode. @@ -93,6 +130,23 @@ def get_scale_shape( del data_shape, is_colwise return (1,) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"x{i}" for i in range(input_rank)) + return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) + class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for block scaling mode. @@ -215,8 +269,44 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. -# (Phuong: Map the NVTEScalingMode value to the ScalingMode + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + + Returns: + The Shardy rules for the scaling mode + """ + input_spec = [f"x{i}" for i in range(input_rank)] + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. + rowwise_var = unique_var + colwise_var = f"{unique_var}_" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") + input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + rowwise = input_spec.copy() + rowwise[-1] = rowwise_var + + colwise = input_spec.copy() + colwise[flatten_axis - 1] = colwise_var + + # This implementation needs to be updated for different block dims. + assert self._block_dims == (1, 32) + + return QuantizeShardyRules( + tuple(input_spec), + tuple(rowwise), + tuple(colwise), + {"block_size_rowwise": 32, "block_size_colwise": 32}, + ) @dataclass(frozen=True) @@ -225,16 +315,14 @@ class ScalingMode(Enum): """Enumeration of tensor scaling modes with their corresponding metadata implementations. This class defines the available scaling modes for tensor quantization: - - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_INVALID_SCALING: Invalid scaling mode - - NVTE_NO_SCALING: No scaling applied + - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales + - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - NO_SCALING: No scaling applied """ - NVTE_DELAYED_TENSOR_SCALING = 0 - NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 100 - NVTE_NO_SCALING = 1000 + NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -293,6 +381,20 @@ def get_scale_shape( """ return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis=-1 + ) -> Tuple[Tuple[str]]: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + + Returns: + The Shardy rules for the scaling mode + """ + return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + def __eq__(self, other): """Compare this scaling mode with another. @@ -329,8 +431,8 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR - ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c34a235d94..0ef30f4728 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -236,13 +236,12 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # TODO(Phuong): Handle padding !? scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv - # TODO(Phuong): constaint padded scale_inv? return ScaledTensor1x( data=data, scale_inv=scale_inv, diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index a9fc6b6b6f..ef3c05a882 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -101,7 +101,7 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["jax", "flax>=0.7.1"], - tests_require=["numpy", "praxis"], + tests_require=["numpy"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6440c628cd..3db13593f5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -19,6 +19,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.utils import ( get_cudnn_version, nvtx_range_pop, @@ -80,6 +81,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb +from .cpu_offload import mark_activation_offload # Setup Attention Logging @@ -616,7 +618,7 @@ def forward( rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -1564,7 +1566,7 @@ def backward(ctx, dout): rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) @@ -4321,10 +4323,9 @@ def forward( from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] - for tensor in tensor_list: - if tensor is not None: - tensor.activation_offloading = True + mark_activation_offload( + query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv + ) with self.attention_dropout_ctx(): # | API | use cases @@ -4726,12 +4727,9 @@ def forward( else: tensor_list = [q, k, v, out_save] - tensor_list.extend(aux_ctx_tensors) - qkv_layout = "sbhd_sbhd_sbhd" - for tensor in tensor_list: - if tensor is not None: - tensor.activation_offloading = True + mark_activation_offload(*tensor_list) + mark_activation_offload(*aux_ctx_tensors) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 @@ -6480,6 +6478,8 @@ class MultiheadAttention(torch.nn.Module): equal length. Please note that these formats do not reflect how tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. For that, please use `get_qkv_layout` to gain the layout information. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -6558,6 +6558,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -6609,6 +6610,8 @@ def __init__( self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups + self.name = name + common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, "tp_group": tp_group, @@ -6649,6 +6652,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_qkv" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6660,6 +6664,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=parameters_split, + name=name + ".linear_qkv" if name is not None else None, **common_gemm_kwargs, ) elif self.attention_type == "cross": @@ -6681,6 +6686,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_q" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6691,6 +6697,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + name=name + ".linear_q" if name is not None else None, **common_gemm_kwargs, ) self.key_value = Linear( @@ -6701,6 +6708,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=("key", "value") if not fuse_qkv_params else None, + name=name + ".linear_kv" if name is not None else None, **common_gemm_kwargs, ) @@ -6730,6 +6738,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, ub_name="proj", + name=name + ".proj" if name is not None else None, **common_gemm_kwargs, ) @@ -6920,6 +6929,9 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..b970d0549d 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -9,11 +9,11 @@ import torch import transformer_engine_torch as tex from ..constants import TE_DType -from ..utils import assert_dim_for_fp8_exec, get_sm_count +from ..utils import get_sm_count from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ "general_gemm", @@ -27,46 +27,6 @@ def _empty_tensor() -> torch.Tensor: return torch.Tensor().cuda() -def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str): - """Swizzle gemm inputs and return original scaling factor inverses.""" - if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase): - return None - - original_scale_inverses = ( - A._rowwise_scale_inv, - A._columnwise_scale_inv, - B._rowwise_scale_inv, - B._columnwise_scale_inv, - ) - - if layout[0] == "T": - A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv) - else: - A._columnwise_scale_inv = tex.columnwise_swizzle( - A._columnwise_data, A._columnwise_scale_inv - ) - - if layout[1] == "N": - B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv) - else: - B._columnwise_scale_inv = tex.columnwise_swizzle( - B._columnwise_data, B._columnwise_scale_inv - ) - - return original_scale_inverses - - -def reset_swizzled_inputs(A, B, scale_inverses): - """Reset the swizzled scale inverses after GEMM.""" - if scale_inverses is not None: - ( - A._rowwise_scale_inv, - A._columnwise_scale_inv, - B._rowwise_scale_inv, - B._columnwise_scale_inv, - ) = scale_inverses - - def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -109,9 +69,20 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + debug_quantizer = None + if isinstance(quantization_params, DebugQuantizer): + debug_quantizer = quantization_params + quantization_params = quantization_params.parent_quantizer + A = A.get_tensor(not transa) + B = B.get_tensor(transb) + # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + # There is not use_split_accumulator == False + # implementation for Float8BlockwiseQTensorBase GEMM + use_split_accumulator = True args = ( A, transa, # transa @@ -137,9 +108,10 @@ def general_gemm( "bulk_overlap": bulk_overlap, } - original_scale_inverses = swizzle_inputs(A, B, layout) out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - reset_swizzled_inputs(A, B, original_scale_inverses) + + if debug_quantizer is not None: + out = debug_quantizer.process_gemm_output(out) return out, bias_grad, gelu_input, extra_output @@ -169,14 +141,6 @@ def general_grouped_gemm( transa = layout[0] == "T" transb = layout[1] == "T" - # assert [a.is_contiguous() for a in A] - # assert [b.is_contiguous() for b in B] - - if isinstance(A[0], Float8TensorBase): - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a._data) - assert_dim_for_fp8_exec(b._data) - empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 93df512ac6..814e699557 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,18 +16,22 @@ CPUOffloadEnabled = False -def set_offloading_param(tensor, param_name, value): +def mark_activation_offload(*tensors): """Set the type of the offloading needed for a tensor.""" - assert param_name in ["weight_offloading", "activation_offloading"] - if tensor is None: - return - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: - setattr(tensor, param_name, value) - else: - data_tensors = tensor.get_data_tensors() - for tensor in data_tensors: - if tensor is not None: - setattr(tensor, param_name, value) + for tensor in tensors: + if tensor is None: + continue + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + tensor.activation_offloading = True + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + tensor.activation_offloading = True + # This is a hack to force clear the tensor after it is offloaded. + # It is needed, because .*TensorBase classes are saved in the ctx, + # and they contain the reference to their data tensors. + tensor.needs_force_clear = True def is_cpu_offload_enabled() -> bool: @@ -459,8 +463,15 @@ def synchronize_on_group_commit_forward(self, current_group): torch.cuda.current_stream().wait_stream(self.d2h_stream) # Time to free the activation memory after usage - for tensor_tag, _ in self.tensor_tag_to_buf.items(): + for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): if tensor_tag[0] == self.offloaded_group_count: + if hasattr(tensor_buf, "needs_force_clear"): + # Need to clear activation tensor - sometimes references persist in the code. + # This is the case for example with the Float8TensorBase class, + # which is saved directly inside the ctx while its internal tensors are + # saved inside save_for_backward. + tensor_buf.data = torch.Tensor() + # Release the pointer to the tensor self.tensor_tag_to_buf[tensor_tag] = None # Time to offload the next group @@ -538,7 +549,7 @@ def get_cpu_offload_context( num_layers: int = 1, model_layers: int = 1, offload_activations: bool = True, - offload_weights: bool = True, + offload_weights: bool = False, ): """ This function returns the CPU Offload context and the synchronizer function that needs to be @@ -570,28 +581,30 @@ def get_cpu_offload_context( """ - def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor, "activation_offloading") - - # This includes the Gradient Accumulation Buffer - def tensor_need_offloading_checker_weights(tensor): - return hasattr(tensor, "weight_offloading") - - def tensor_need_offloading_checker_all(tensor): - return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") - - if offload_activations and offload_weights: - tensor_need_offloading_checker = tensor_need_offloading_checker_all - elif offload_activations: - tensor_need_offloading_checker = tensor_need_offloading_checker_activations - elif offload_weights: - tensor_need_offloading_checker = tensor_need_offloading_checker_weights - else: + if not offload_weights and not offload_activations: raise ValueError( "CPU Offloading is enabled while it is not " "mentioned what to offload (weights/activations)" ) + if offload_weights: + import warnings + + warnings.warn( + "Offloading weights is deprecated. Using offload_weights=True does not have any" + " effect.", + DeprecationWarning, + ) + + # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. + if not offload_activations: + return nullcontext(), lambda x: x + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, num_model_group=model_layers, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..3b349e7f09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + + private: int block_scaling_dim = 2; public: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a66fbf950d..770517a051 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -50,11 +50,11 @@ std::vector fused_attn_fwd( NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, const c10::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, + const std::optional page_table_k, const std::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + const std::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -63,8 +63,8 @@ std::vector fused_attn_bwd( const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer); at::Tensor fa_prepare_fwd(at::Tensor qkvi); @@ -101,18 +101,22 @@ std::optional> te_general_grouped_gemm( bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +namespace transformer_engine::pytorch { + /*************************************************************************************************** * Transpose **************************************************************************************************/ -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, std::vector quantizer_list, transformer_engine::DType otype); at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output = std::nullopt); +} // namespace transformer_engine::pytorch + namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -266,12 +270,12 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank); /*************************************************************************************************** @@ -392,8 +396,6 @@ void nvshmem_finalize(); * swizzle **************************************************************************************************/ -void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); - at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 1ef6f5258d..bf037fe931 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + // sanity check, since activation fusion is not supported for blockwise quantization yet + // need to raise an error here instead of silently going into act_func with wrong numerics + NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); } else { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 424a988301..3414975b0e 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -8,7 +8,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank) { using namespace transformer_engine::pytorch; @@ -96,7 +96,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank) { using namespace transformer_engine::pytorch; TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index da82120f4a..37b6840f1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -3,9 +3,11 @@ * * See LICENSE for license information. ************************************************************************/ + #include "extensions.h" #include "kv_cache.cuh" #include "thd_utils.cuh" +#include "transformer_engine/transformer_engine.h" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -90,11 +92,11 @@ std::vector fused_attn_fwd( NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, const c10::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, + const std::optional page_table_k, const std::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + const std::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; using namespace transformer_engine::pytorch; TensorWrapper te_Q, te_K, te_V, te_O, te_S; @@ -280,8 +282,8 @@ std::vector fused_attn_bwd( const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer) { using namespace transformer_engine; using namespace transformer_engine::pytorch; @@ -449,13 +451,13 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - auto temp_vec = std::vector(tmp.begin(), tmp.end()); - const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()}; + const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); + const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - temp_shape}; + nvte_make_shape(tmp.data(), tmp.size())}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c3ccff154..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob if (te_output.numel() == 0) return out; + QuantizationConfigWrapper quant_config; + quant_config.set_noop_tensor(te_noop.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; + // this config is used for cs scaling factor computation + // because compute scale is cannot be fused with quantize kernel + // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index ff61cd940c..5860d9ff2c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -17,6 +17,7 @@ #include "extensions.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" +#include "util.h" namespace { @@ -175,8 +176,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + // Optionally swizzle the scaling factors + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(swizzle_scaling_factors(B_tensor, !transb))); + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; @@ -313,6 +321,8 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_vector, te_workspace_vector; std::vector wrappers; std::vector D_vectors; + // Keep the swizzled scaling factor tensors alive during the GEMMs. + std::vector> swizzled_scale_inverses_list; auto none = py::none(); @@ -379,6 +389,10 @@ std::optional> te_general_grouped_gemm( continue; } + // Optionally swizzle the scaling factors + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); + auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index cbdeee5b48..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,6 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -166,15 +167,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -293,6 +297,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -309,15 +314,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 47282da504..1dc8dd0d17 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -52,18 +52,11 @@ std::tuple> moe_permute_fwd( sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, num_tokens * topK); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input.scalar_type(); - // Output buffer alloc num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; - at::Tensor permuted_output = torch::empty( - {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor permuted_output = + torch::empty({num_out_tokens, num_cols}, + torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); at::Tensor row_id_map = torch::empty( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); @@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d using namespace transformer_engine::pytorch; int num_cols = input.size(1); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input.scalar_type(); - // Output buffer alloc - at::Tensor unpermuted_output = torch::empty( - {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor unpermuted_output = + torch::empty({num_tokens, num_cols}, + torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -136,17 +122,10 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input_bwd.scalar_type(); - // Output buffer alloc - at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, - torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor act_grad = + torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false)); at::Tensor prob_grad = torch::empty( {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 617ba42d4a..9b11ec5685 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -196,12 +196,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); - m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", - py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize, + "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"), + py::arg("quantizer_list"), py::arg("otype")); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), - py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", + py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), + py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax")); diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ac6292e53..3be719eaf6 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -109,6 +109,7 @@ std::pair Float8Quantizer::create_tensor( } const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); opts = opts.dtype(torch::kFloat32); + // TODO: Replace with an empty tensor. at::Tensor scale_inv = at::reciprocal(scale); py::object ret; if (internal) { @@ -257,12 +258,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires pow2 scales"); - NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires amax_epsilon==0"); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); } diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index b127b5d75b..a16c43d2d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -6,14 +6,16 @@ #include "extensions.h" #include "transformer_engine/transformer_engine.h" +#include "util.h" -void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { +std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, + bool rowwise) { using namespace transformer_engine::pytorch; if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { - return; + return std::nullopt; } NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); @@ -48,9 +50,9 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); - output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } @@ -63,6 +65,8 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww } else { input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } + + return swizzled_scale_inv; } at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index a873586032..5b8c121517 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -6,27 +6,38 @@ #include -#include "ATen/core/TensorBody.h" #include "extensions.h" +#include "pybind.h" -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, +namespace transformer_engine::pytorch { + +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, std::vector quantizer_list, transformer_engine::DType otype) { - using namespace transformer_engine::pytorch; + init_extension(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; std::vector py_output_objects_list; std::vector tensor_wrappers; - auto none = py::none(); + if (output_list.has_value()) { + py_output_objects_list = output_list.value(); + } + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; // create TE tensors from input for (size_t i = 0; i < input_list.size(); i++) { - auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + auto input_tensor = makeTransformerEngineTensor(input_list[i]); const NVTEShape input_shape = input_tensor.shape(); transformer_engine::TensorWrapper output_tensor; + if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) { + with_fused_kernel = false; + } if (output_list == std::nullopt) { std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); @@ -48,16 +59,8 @@ std::vector fused_multi_quantize(std::vector input_list, NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), "Number of input and output tensors must match"); - // Choose implementation - // Note: Currently only have fused kernel for FP8 cast-transpose - bool with_fused_kernel = true; for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - const auto& tensor = nvte_tensor_output_list[i]; - if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { - with_fused_kernel = false; - break; - } - if (nvte_tensor_columnwise_data(tensor) == nullptr) { + if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) { with_fused_kernel = false; break; } @@ -68,9 +71,8 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { - for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], - at::cuda::getCurrentCUDAStream()); + for (size_t i = 0; i < py_output_objects_list.size(); i++) { + quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt); } } return py_output_objects_list; @@ -78,7 +80,7 @@ std::vector fused_multi_quantize(std::vector input_list, at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output) { - using namespace transformer_engine::pytorch; + init_extension(); const auto dim = input.dim(); NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); @@ -105,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, return out; } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index cbdf0833ed..a69e2cc24f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -7,6 +7,19 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#include + +#include + +#include "transformer_engine/transformer_engine.h" + bool non_tn_fp8_gemm_supported(); +/* Swizzle the scaling factor of the input tensor. + * + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + */ +std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, + bool trans); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..fe77b69cad 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -19,15 +19,24 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules -from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data +from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm from .constants import dist_group_type from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor +try: + import torch.distributed._symmetric_memory as symm_mem + + HAS_TORCH_SYMMETRIC = True +except ImportError: + HAS_TORCH_SYMMETRIC = False __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -660,10 +669,13 @@ def checkpoint( **kwargs, ) - # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need - # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + from .module.base import TransformerEngineBaseModule + + if isinstance(function, TransformerEngineBaseModule): + # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need + # to scatter/gather activations that we will recompute anyway. + setattr(function, "fsdp_wrapped", False) + setattr(function, "fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing @@ -937,6 +949,74 @@ def _all_gather_fp8( return out, handle +def _all_gather_fp8_blockwise( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, # pylint: disable=unused-argument + quantizer: Optional[Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """ + All-gather FP8 tensor along first dimension for blockwise quantization. + + Returns: quantizer(gather(inp)) + + NOTE: The implementation is not sophisticated enough to honor async_op=True. + In some cases it falls back to synchronous gather and invokes the quantizer. + """ + + # Input tensor attributes + device: torch.device + dtype: torch.dtype + if isinstance(inp, torch.Tensor): + device = inp.device + dtype = inp.dtype + elif isinstance(inp, Float8BlockwiseQTensorBase): + if inp._rowwise_data is not None: + device = inp._rowwise_data.device + elif inp._columnwise_data is not None: + device = inp._columnwise_data.device + else: + raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " + f"found {inp.__class__.__name__})" + ) + world_size = get_distributed_world_size(process_group) + + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") + + # Output tensor dims + if out_shape is None: + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # Doing BF16 gather for now as baseline because it's simpler + if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + out = quantizer(out) + return out, None + # Implementation of fp8 gather needs to account for: + # * Getting columnwise data as a transpose of how it is stored for GEMMS. + # * Gathering non GEMM swizzled scales. + # * Refer to scaffold code when implementing at: + # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 + raise NotImplementedError("fp8 blockwise allgather not yet implemented") + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1075,7 +1155,9 @@ def gather_along_first_dim( async_op: bool = False, quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: - """All-gather tensors and concatenate along first dimension.""" + """ + All-gather tensors and concatenate along first dimension. + """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) @@ -1100,6 +1182,16 @@ def gather_along_first_dim( out_shape=out_shape, ) + # FP8 block scaling case, block length = 128 + if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + return _all_gather_fp8_blockwise( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # MXFP8 case if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) @@ -1111,6 +1203,28 @@ def gather_along_first_dim( out_shape=out_shape, ) + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = inp + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor + return out_obj, None + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( @@ -1153,6 +1267,152 @@ def gather_along_first_dim( return out, handle +# Global cache to store symmetric memory tensors +symmetric_mem_cache = {} + + +def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None): + """ + Gets or creates a symmetric memory tensor with specified properties. + + Reuses cached tensors when available to avoid redundant creation and rendezvous operations. + + Note: This function always returns a 1D tensor. + + Parameters + ---------- + tensor_numel : int + Number of elements in the tensor. + tensor_dtype : torch.dtype + Data type of the tensor. + tensor_device : torch.device + Device on which to allocate the tensor. + tp_group : dist_group_type + Process group for rendezvous operation. + tag : Any, optional + Optional identifier to further distinguish tensors. + + Returns + ------- + torch.Tensor + A symmetric memory tensor with the specified properties. + """ + # Create a cache key based on tensor properties and group + cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag) + + # Check if we already have a symmetric memory tensor for this configuration + if cache_key not in symmetric_mem_cache: + # Create a new symmetric memory tensor if not in cache + msg = symm_mem.empty( + tensor_numel, + dtype=tensor_dtype, + device=tensor_device, + ) + # Perform the rendezvous once for this tensor + symm_mem.rendezvous(msg, group=tp_group) + # Store in cache + symmetric_mem_cache[cache_key] = msg + else: + # Reuse the existing symmetric memory tensor + msg = symmetric_mem_cache[cache_key] + + return msg + + +def symmetric_all_reduce( + inp: torch.Tensor, + tp_group: Optional[dist_group_type] = None, + async_op: bool = False, + all_reduce_type: str = "multimem_all_reduce", +): + """ + Performs an all-reduce operation across multiple processes using symmetric memory. + If the input tensor is already in the symmetric memory cache we can avoid copy + overheads by just directly using the input tensor for all reduce. Externally + created symmetric memory tensors not in the cache currently will not be able to + avoid the extra copies. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to be reduced. The operation is performed in-place. + + tp_group : Optional[dist_group_type], default=None + The process group over which to perform the all-reduce operation. + If None, the default process group is used. + + async_op : bool, default=False + Whether to perform the operation asynchronously. + Note: Currently only synchronous operations are supported for symmetric memory variants. + + all_reduce_type : str, default="multimem_all_reduce" + The type of all-reduce implementation to use. Options include: + - "nccl": Standard PyTorch distributed all-reduce + - "multimem_all_reduce": multimem symmetric all-reduce + - "two_shot": Two-shot symmetric all-reduce + - "one_shot": One-shot symmetric all-reduce + + Returns + ------- + Tuple[torch.Tensor, Optional[torch.distributed.Work]] + - The first element is the input tensor with the all-reduce result. + - The second element is the async work handle if async_op=True, + otherwise None. + """ + assert async_op is False, "Async symmetric ops no supported yet" + assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch" + + if get_distributed_world_size(tp_group) == 1: + return inp, None + + if all_reduce_type == "nccl": + # Standard all-reduce implementation + handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op) + return inp, handle + + all_reduce_impl = None + if all_reduce_type == "multimem_all_reduce": + all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_ + elif all_reduce_type == "two_shot": + all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_ + elif all_reduce_type == "one_shot": + all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce + else: + raise TypeError(f"All reduce type {all_reduce_type} is not supported.") + + group_name = tp_group.group_name + tensor_shape = inp.shape + tensor_numel = inp.numel() + tensor_dtype = inp.dtype + tensor_device = inp.device + + input_id = id(inp) + is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values()) + # Check if the input tensor is already in the symmetric memory cache. If it is we can avoid copy overheads. + if is_cached: + all_reduce_impl( + inp, + "sum", + group_name, + ) + else: + # Get symmetric memory tensor. Build or retrieve from cache. + msg = get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group) + + msg.copy_(inp.reshape(-1)) + + all_reduce_impl( + msg, + "sum", + group_name, + ) + + # Copy the result back to the input tensor + inp.copy_(msg.reshape(tensor_shape)) + + return inp, None + + def allreduce( inp: torch.Tensor, tp_group: Optional[dist_group_type] = None, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 38f829c079..c02ff73391 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -19,6 +20,7 @@ Format, MXFP8BlockScaling, Float8CurrentScaling, + Float8BlockScaling, ) from .constants import dist_group_type @@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ): + return True, "" + return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." + + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -109,6 +122,8 @@ class FP8GlobalStateManager: skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None @classmethod def reset(cls) -> None: @@ -134,6 +149,8 @@ def reset(cls) -> None: cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -161,6 +178,15 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + if cls.fp8_block_scaling_available is None: + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( + check_fp8_block_scaling_support() + ) + return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -434,6 +460,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -786,8 +815,10 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState + elif recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState else: - raise ValueError("{recipe.__class__.__name__} is not supported") + raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, @@ -928,3 +959,108 @@ def make_quantizers(self) -> list: from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index c2b525ab55..4828e9bc10 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -9,6 +9,7 @@ from functools import reduce from operator import mul as multiply_op +import queue import torch from .. import cpp_extensions as tex @@ -216,3 +217,79 @@ def __post_init__(self): """Safeguard reference to the parameter's parent module and initialization function.""" if self.init_fn is None: self.init_fn = get_default_init_method() + + +class WeightGradStore: + """ + A class to manage weight gradient storage and computation in Transformer modules. + This class enables split backward propagation for better memory efficiency. + """ + + def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False): + """ + Initialize the WeightGradStore. + + Args: + delay_wgrad_compute (bool): Whether to delay weight gradient computation + ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation + """ + if delay_wgrad_compute: + self.context = queue.Queue() + assert ( + ub_bulk_wgrad is False + ), "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute" + self.enabled = delay_wgrad_compute + else: + self.context = None + self.enabled = False + + def delay_wgrad_compute(self): + """ + Get the current split backward propagation status. + + Returns: + bool: True if split backward is enabled, False otherwise + """ + return self.enabled + + def enable_delay_wgrad_compute(self): + """Enable split backward propagation.""" + self.enabled = True + + def disable_delay_wgrad_compute(self): + """Disable split backward propagation.""" + self.enabled = False + + def put(self, tensor_list, func): + """ + Store tensors and computation function for later execution. + + Args: + tensor_list (list): List of tensors needed for computation + func (callable): Function to be executed with the tensors + """ + assert self.enabled is True, "delay_wgrad_compute is not enabled" + self.context.put([tensor_list, func]) + + def pop(self): + """ + Execute the stored computation with the stored tensors. + Raises an exception if the queue is empty. + """ + assert self.enabled is True, "delay_wgrad_compute is not enabled" + if self.context.qsize() > 0: + tensor_list, func = self.context.get() + return func(*tensor_list), tensor_list + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + raise RuntimeError(f"Pop empty queue. rank {rank}") + raise RuntimeError("Pop empty queue. No distributed environment detected.") + + def assert_empty(self): + """ + Assert that the queue is empty. + Used for debugging and ensuring proper cleanup. + """ + assert self.enabled is True, "delay_wgrad_compute is not enabled" + rank = torch.distributed.get_rank() + assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..17848a36bf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +import logging from types import MethodType import torch @@ -18,11 +19,12 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from ._common import _ParameterInitMeta +from ._common import _ParameterInitMeta, noop_cat from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,8 +36,13 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...common.recipe import Recipe +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor __all__ = ["initialize_ub", "destroy_ub"] @@ -43,6 +50,7 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _multi_stream_cublas_workspace = [] +_dummy_wgrads = {} _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 @@ -78,6 +86,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: return _multi_stream_cublas_workspace +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + + def initialize_ub( shape: list, tp_size: int, @@ -393,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + self.name = None self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -412,6 +437,9 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None + if not TEDebugState.debug_enabled: + TEDebugState.initialize() + # Names of attributes that can be set quickly (see __setattr__ # method) _fast_setattr_names: Set[str] = { @@ -499,6 +527,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -824,7 +856,7 @@ def grad_output_preprocess( gather_grad_output = row_parallel_mode and ctx.sequence_parallel # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8: + if not ctx.fp8 and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) @@ -834,6 +866,7 @@ def grad_output_preprocess( return grad_output, None # FP8 with all-gather: unfused bgrad, fused cast + transpose + # Also supports debug quantization, which is handled inside gather_along_first_dim. if gather_grad_output: grad_bias = None if ctx.use_bias: @@ -841,7 +874,13 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ): grad_output = quantizer(grad_output) @@ -856,14 +895,41 @@ def grad_output_preprocess( ) return grad_output, grad_bias + # Debug without all-gather: unfused cast and bgrad + # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None + if ctx.debug: + grad_output_ = quantizer(grad_output) + if ( + isinstance( + grad_output_.get_tensor(True), + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ) + and ctx.use_bias + ): + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias = None + grad_output = grad_output_ + return grad_output, grad_bias + # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(quantizer, Float8BlockQuantizer): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_output = quantizer(grad_output) return grad_output, grad_bias @@ -962,6 +1028,7 @@ def get_weight_workspace( update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values @@ -984,6 +1051,9 @@ def get_weight_workspace( over `update_workspace` if provided. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. """ # FP8 primary weights @@ -997,6 +1067,7 @@ def get_weight_workspace( # Try getting workspace from cache out = None + if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) if quantizer is not None and isinstance(out, MXFP8TensorBase): @@ -1007,6 +1078,11 @@ def get_weight_workspace( out = None del self._fp8_workspaces[cache_name] + is_debug = isinstance(quantizer, DebugQuantizer) + is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) + if is_debug != is_out_debug_tensor: + out = None + # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # for models initialized with Fp8 primary weights. @@ -1024,7 +1100,7 @@ def get_weight_workspace( raise ValueError( "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - out = quantizer(tensor) + out = quantizer.quantize(tensor, dtype=workspace_dtype) # Update cache if cache_name is not None: @@ -1041,7 +1117,6 @@ def get_weight_workspace( out.quantize_(tensor, noop_flag=skip_update_flag) else: tex.quantize(tensor, quantizer, out, skip_update_flag) - return out def _load_from_state_dict( @@ -1064,3 +1139,68 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + return + with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): + (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + unfused_weights = [getattr(self, name) for name in self.weight_names] + weight_tensor = noop_cat(unfused_weights) + if weight_tensor.grad is None: + weight_tensor.grad = wgrad.to(weight_tensor.dtype) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) + del grad_bias_ + del wgrad + + def _validate_name(self): + """ + Validate name passed to the module. + This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. + If no name is assigned, it creates a default name with layer count as the variable. + """ + assert TEDebugState.debug_enabled + import nvdlfw_inspect.api as debug_api + + if self.name is None: + debug_api.log_message( + "Names are not provided to debug modules. ", + "Creating and using generic names. Pass names to debug modules for better" + " insight. ", + level=logging.WARNING, + ) + self.name = f"Layer_{TEDebugState.get_layer_count()}" + + def _turn_off_unsupported_features_in_debug(self): + if ( + getattr(self, "ub_bulk_wgrad", False) + or getattr(self, "ub_bulk_dgrad", False) + or getattr(self, "ub_overlap_ag", False) + or getattr(self, "ub_overlap_rs_dgrad", False) + or getattr(self, "ub_overlap_rs", False) + ): + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + "UserBuffers are not supported in debug module. " + "Using UB optimization will not affect the debug module. ", + level=logging.WARNING, + ) + if hasattr(self, "ub_bulk_wgrad"): + self.ub_bulk_wgrad = None + if hasattr(self, "ub_bulk_dgrad"): + self.ub_bulk_dgrad = None + if hasattr(self, "ub_overlap_ag"): + self.ub_overlap_ag = None + if hasattr(self, "ub_overlap_rs_dgrad"): + self.ub_overlap_rs_dgrad = None + if hasattr(self, "ub_overlap_rs"): + self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 2549d45728..9748408338 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -4,12 +4,13 @@ """FP8 Padding API""" -from typing import Union, List +from typing import List, Optional, Tuple import torch import transformer_engine_torch as tex +from ..fp8 import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module): ---------- num_gemms: int number of GEMMs to be performed simutaneously. + align_size: int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. """ def __init__( self, - num_gemms, + num_gemms: int, + align_size: Optional[int] = None, ) -> None: super().__init__() self.num_gemms = num_gemms + if align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + else: + self.align_size = align_size @no_torch_dynamo() def forward( self, inp: torch.Tensor, m_splits: List[int], - ) -> Union[torch.Tensor, List[int]]: + ) -> Tuple[torch.Tensor, List[int]]: """ Apply the padding to the input. @@ -104,7 +113,12 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." # FP8 padding calculate - padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + padded_m_splits = [ + (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits + ] + # no padding needed + if m_splits == padded_m_splits: + return inp, m_splits if torch.is_grad_enabled(): fn = _Fp8Padding.apply diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 479b91d396..7e1fbcb2a3 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -4,12 +4,13 @@ """FP8 Padding API""" -from typing import List +from typing import List, Optional import torch import transformer_engine_torch as tex +from ..fp8 import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module): ---------- num_gemms: int number of GEMMs to be performed simutaneously. + align_size: int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. """ def __init__( self, - num_gemms, + num_gemms: int, + align_size: Optional[int] = None, ) -> None: super().__init__() self.num_gemms = num_gemms + if align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + else: + self.align_size = align_size @no_torch_dynamo() def forward( @@ -100,7 +109,12 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." # FP8 padding calculate - padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + padded_m_splits = [ + (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits + ] + # no padding needed + if m_splits == padded_m_splits: + return inp if torch.is_grad_enabled(): fn = _Fp8Unpadding.apply diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..f9bb7d767a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,10 +5,12 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List +import functools import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_multi_stream_cublas_workspace, TransformerEngineBaseModule, @@ -16,6 +18,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ._common import WeightGradStore from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, @@ -37,7 +40,6 @@ from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor.float8_tensor import Float8Tensor from ..cpu_offload import is_cpu_offload_enabled from ..tensor.quantized_tensor import ( @@ -47,7 +49,6 @@ restore_from_saved, ) - __all__ = ["GroupedLinear"] @@ -65,6 +66,7 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, input_quantizers: List[Quantizer], weight_quantizers: List[Quantizer], output_quantizers: List[Quantizer], @@ -85,13 +87,6 @@ def forward( biases = weights_and_biases[num_gemms:] device = inp.device - # TODO Support MXFP8 # pylint: disable=fixme - if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): - raise NotImplementedError("GroupedLinear does not yet support MXFP8") - # TODO Support Float8 Current Scaling # pylint: disable=fixme - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): - raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") - # Make sure input dimensions are compatible in_features = weights[0].shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" @@ -124,7 +119,11 @@ def forward( for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator inputmats = tex.fused_multi_quantize( inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] ) @@ -165,7 +164,7 @@ def forward( m_splits=m_splits, bias=biases, use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ) if fp8_calibration: @@ -177,9 +176,19 @@ def forward( weight_quantizers[i].calibrate(weights[i]) if is_grad_enabled: - + ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] + # TODO: update after #1638 is merged. # pylint: disable=fixme + if weight_requires_grad: + for inputmat in inputmats: + if isinstance(inputmat, QuantizedTensor): + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if inp.requires_grad: + for weight in weights_fp8: + if isinstance(weight, QuantizedTensor): + weight.update_usage(columnwise_usage=True) + tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, @@ -200,6 +209,7 @@ def forward( ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -213,6 +223,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() ) + ctx.wgrad_store = wgrad_store # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -245,10 +256,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_biases = [None] * ctx.num_gemms if ctx.fp8: if ctx.use_bias: - for i in range(ctx.num_gemms): - grad_biases[i], grad_output[i] = tex.bgrad_quantize( - grad_output_mats[i], ctx.grad_output_quantizers[i] - ) + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready + # for Float8BlockQuantizer. + if ctx.fp8_recipe.float8_block_scaling(): + for i in range(ctx.num_gemms): + grad_biases[i] = grad_output_mats[i].sum(dim=0) + grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i]) + else: + for i in range(ctx.num_gemms): + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] + ) else: grad_output = tex.fused_multi_quantize( grad_output_mats, @@ -267,12 +285,25 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.requires_dgrad: + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_dgrad.use_split_accumulator + ) dgrad = torch.empty( (sum(ctx.m_splits), ctx.weights_shape_1), dtype=ctx.activation_dtype, device=ctx.device, ) + for weight, quantizer in zip(weights, ctx.weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) general_grouped_gemm( weights, grad_output, @@ -283,10 +314,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], layout="NN", m_splits=ctx.m_splits, grad=True, - use_split_accumulator=_2X_ACC_DGRAD, + use_split_accumulator=dgrad_gemm_use_split_accumulator, ) if ctx.weights_requires_grad: + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: @@ -294,28 +332,31 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - # WGRAD - _, grad_biases_, _ = general_grouped_gemm( - inputmats, - grad_output, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, + out_dtype=ctx.activation_dtype, + workspaces=get_multi_stream_cublas_workspace(), layout="NT", grad=True, m_splits=ctx.m_splits, use_bias=ctx.use_bias if grad_biases[0] is None else None, bias=biases, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, ) - for i in range(ctx.num_gemms): - if grad_biases[i] is None: - grad_biases[i] = grad_biases_[i] - del grad_biases_ + # WGRAD + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + else: + _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) - # Deallocate input tensor - clear_tensor_data(*inputmats) + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ + + # Deallocate input tensor + clear_tensor_data(*inputmats) def handle_custom_ddp_from_mcore(weight, wgrad): if ctx.weights_requires_grad: @@ -351,7 +392,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias: + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + wgrad_list = [None] * ctx.num_gemms + + if not ctx.use_bias or ( + ctx.wgrad_store is not None + and ctx.wgrad_store.delay_wgrad_compute() + and not ctx.fp8 + ): grad_biases = [None] * ctx.num_gemms if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): @@ -372,8 +420,9 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, - None, # is_grad_enabled - None, # is_grad_enabled + None, + None, + None, *wgrad_list, *grad_biases, ) @@ -422,7 +471,12 @@ class GroupedLinear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether to delay weight gradient computation + Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and + `parallel_mode` are used to determine the shapes of weights and biases. + The TP communication should be handled in the dispatch and combine stages of MoE models. """ def __init__( @@ -445,6 +499,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, + delay_wgrad_compute: bool = False, ) -> None: super().__init__() @@ -465,7 +520,13 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} + self.wgrad_store = WeightGradStore(delay_wgrad_compute) + + self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._num_fp8_tensors_per_gemm = { + "fwd": 3, + "bwd": 2, + } if tp_group is None: self.tp_size = tp_size @@ -476,6 +537,12 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if self.tp_size > 1 and bias: + raise ValueError( + "GroupedLinear doesn't support bias when TP > 1. " + "Because the TP communication is handled outside of this module." + ) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -502,7 +569,7 @@ def __init__( ), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=self._offsets["weight"] + i, + fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], ) # Construct bias parameters if needed @@ -527,12 +594,18 @@ def __init__( self.reset_parameters(defer_init=device == "meta") - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.apply_bias: - self.gemm_bias_unfused_add = True - else: - self.gemm_bias_unfused_add = False + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + assert not self.tp_size > 1, ( + "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " + "Because the TP communication is handled outside of this module." + ) + self._customize_quantizers_float8_current_scaling(fwd, recipe) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -590,7 +663,7 @@ def forward( produced) """ assert not isinstance( - inp, Float8Tensor + inp, QuantizedTensor ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." @@ -615,20 +688,27 @@ def forward( grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms if self.fp8: input_quantizers = [ - self.quantizers["scaling_fwd"][self._offsets["input"] + i] + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ] for i in range(self.num_gemms) ] + # TODO: use internal after #1638 is merged. # pylint: disable=fixme for i in range(self.num_gemms): - input_quantizers[i].internal = True + input_quantizers[i].internal = False weight_quantizers = [ - self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ] for i in range(self.num_gemms) ] for i in range(self.num_gemms): weight_quantizers[i].internal = True if torch.is_grad_enabled(): grad_output_quantizers = [ - self.quantizers["scaling_bwd"][self._offsets["input"] + i] + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ] for i in range(self.num_gemms) ] for i in range(self.num_gemms): @@ -643,10 +723,11 @@ def forward( args += ( inp, m_splits, - self.apply_bias and not self.gemm_bias_unfused_add, + self.apply_bias, is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, input_quantizers, weight_quantizers, output_quantizers, @@ -663,17 +744,61 @@ def forward( ) out = linear_fn(*args) - if self.gemm_bias_unfused_add: - out_shape = out.shape - out = torch.cat( - [ - o + cast_if_needed(b, self.activation_dtype) - for o, b in zip( - torch.split(out.view(-1, self.out_features), m_splits), bias_tensors - ) - ] - ).view(out_shape) - if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + return + with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() + wgrad_list = tensor_list[2] + if not self.fuse_wgrad_accumulation: + for i in range(self.num_gemms): + weight_param = getattr(self, f"weight{i}") + if weight_param.grad is None: + weight_param.grad = wgrad_list[i].to(weight_param.dtype) + if self.use_bias: + for i in range(self.num_gemms): + bias_param = getattr(self, f"bias{i}") + if bias_param.grad is None: + bias_param.grad = grad_biases_[i].to(bias_param.dtype) + del grad_biases_ + del wgrad_list + del tensor_list + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + for i in range(self.num_gemms): + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + else: + for i in range(self.num_gemms): + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..d3bfed5885 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -9,16 +9,19 @@ from functools import reduce from operator import mul as multiply_op +import functools import torch from torch.nn import init import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -34,11 +37,13 @@ nvtx_range_pop, nvtx_range_push, requires_grad, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, in_fp8_activation_recompute_phase, @@ -48,17 +53,21 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose +from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, prepare_for_saving, restore_from_saved, ) +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload + from ..cpp_extensions import ( general_gemm, ) @@ -83,12 +92,14 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, fuse_wgrad_accumulation: bool, input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -113,6 +124,8 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -137,11 +150,6 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - # Avoid quantized norm kernel if norm output will be returned - with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered - ) - tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output @@ -174,6 +182,18 @@ def forward( columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. + force_hp_blockwise_ln_out_gather = ( + fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + with_quantized_norm = ( + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not force_hp_blockwise_ln_out_gather + ) + # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( @@ -204,13 +224,13 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: + if fp8 or debug: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = input_quantizer(ln_out_total) else: - if fp8: - if not with_quantized_norm: + if fp8 or debug: + if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -223,18 +243,19 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(input_quantizer if fp8 else None), + quantizer=(input_quantizer if fp8 or debug else None), ) else: - if fp8 and not with_quantized_norm: + if (fp8 or debug) and not with_quantized_norm: ln_out = input_quantizer(ln_out) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - if not fp8: - quantized_weight = False - weightmat = cast_if_needed(weight, activation_dtype) + weightmat = weight + quantized_weight = False + if not fp8 and not debug: + weightmat = cast_if_needed(weightmat, activation_dtype) else: quantized_weight = not isinstance(weight, QuantizedTensor) @@ -244,6 +265,7 @@ def forward( # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( tensor=weight, quantizer=weight_quantizer, @@ -251,11 +273,12 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -313,9 +336,11 @@ def forward( clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: @@ -326,20 +351,16 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # For force_hp_blockwise_ln_out_gather, we should + # be saving the unquantized ln_out to ctx. + assert not force_hp_blockwise_ln_out_gather + # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: - if fp8 and weightmat is not None: - set_offloading_param(weightmat, "weight_offloading", True) - set_offloading_param(ln_weight, "weight_offloading", True) - set_offloading_param(weight, "weight_offloading", True) - - set_offloading_param(inputmat, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(rsigma, "activation_offloading", True) - set_offloading_param(ln_out, "activation_offloading", True) + mark_activation_offload(inputmat, mu, rsigma, ln_out) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -384,6 +405,7 @@ def forward( if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp @@ -418,6 +440,8 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store + ctx.debug = debug # Row Parallel Linear if ub_overlap_rs_fprop: @@ -427,7 +451,10 @@ def forward( if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: - out, _ = allreduce(out, tp_group) + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") # [*, in_features] -> [*, out_features] except first dimension changes for SP @@ -595,7 +622,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -604,11 +631,14 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + # async_op is not compatible with high precision gather since + # gather_along_first_dim does not offer callback chaining. + gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -633,6 +663,11 @@ def backward( if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) dgrad, *_ = general_gemm( weight, grad_output, @@ -689,6 +724,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather may have been done in BF16 + # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -713,24 +755,38 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - wgrad, grad_bias_, *_, rs_out = general_gemm( - ln_out_total, - grad_output, - get_workspace(), - layout="NT", - grad=True, + general_gemm_wgrad = functools.partial( + general_gemm, out_dtype=( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), + workspace=get_workspace(), + layout="NT", + grad=True, bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad) + else: + wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if not ctx.return_layernorm_output: + # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme + clear_tensor_data(ln_out_total) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: @@ -739,16 +795,11 @@ def backward( else: dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate input tensor - if not ctx.return_layernorm_output: - # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme - clear_tensor_data(ln_out_total) + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None - # Make sure all tensor-parallel communication is finished + # Synchronize tensor parallel communication if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None @@ -796,18 +847,15 @@ def backward( if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -833,12 +881,14 @@ def backward( None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # fuse_wgrad_accumulation None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -861,8 +911,10 @@ def backward( None, # ub_bulk_wgrad None, # ub_name None, # fsdp_group + None, # debug None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type ) @@ -915,6 +967,8 @@ class LayerNormLinear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -950,6 +1004,15 @@ class LayerNormLinear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -979,6 +1042,9 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + delay_wgrad_compute: bool = False, + symmetric_ar_type: Optional[str] = None, + name: str = None, ) -> None: super().__init__() @@ -994,6 +1060,12 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.symmetric_ar_type = symmetric_ar_type + + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) + self.name = name + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if tp_group is None: self.tp_size = tp_size @@ -1059,6 +1131,13 @@ def __init__( assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, dtype=params_dtype) @@ -1284,6 +1363,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1320,13 +1402,28 @@ def forward( else: bias_tensor = getattr(self, self.bias_names[0]) # Unused + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1344,12 +1441,14 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, self.fuse_wgrad_accumulation, input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1374,6 +1473,8 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, + debug, ) out = fwd_fn(*args) @@ -1393,8 +1494,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1413,8 +1515,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..b5f574f766 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -8,6 +8,7 @@ from typing import Callable, Optional, Tuple, Union from functools import reduce from operator import mul as multiply_op +import functools import torch from torch.nn.parameter import Parameter @@ -16,6 +17,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, _ub_communicators, @@ -41,18 +43,19 @@ clear_tensor_data, requires_grad, non_tn_fp8_gemm_supported, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, use_reentrant_activation_recompute, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) - from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -62,8 +65,9 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ._common import apply_normalization, _fix_gathered_fp8_transpose -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -73,6 +77,8 @@ from ..cpp_extensions import ( general_gemm, ) +from ...debug.pytorch.utils import any_feature_enabled +from ...debug.pytorch.debug_state import TEDebugState __all__ = ["LayerNormMLP"] @@ -104,17 +110,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } # no activation fusion written yet - # Per-tensor current scaling: [] - return { - "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), - "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, None), - "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, None), - } + # Per-tensor current scaling or fp8 blockwise scaling: [] + if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), + } + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -122,7 +130,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] - # Per-tensor current scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling: [] funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -148,15 +156,20 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, fuse_wgrad_accumulation: bool, fc1_input_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer], + fc1_output_quantizer: Optional[Quantizer], + fc1_grad_input_quantizer: Optional[Quantizer], + fc1_grad_weight_quantizer: Optional[Quantizer], + fc1_grad_output_quantizer: Optional[Quantizer], fc2_input_quantizer: Optional[Quantizer], fc2_weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_fc2_output_quantizer: Optional[Quantizer], - grad_fc1_output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], + fc2_output_quantizer: Optional[Quantizer], + fc2_grad_input_quantizer: Optional[Quantizer], + fc2_grad_weight_quantizer: Optional[Quantizer], + fc2_grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -182,6 +195,8 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -210,16 +225,31 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - # Avoid quantized norm kernel if norm output will be returned + # for fp8 DelayedScaling: layernorm output = FP8 + # only output of the linear is returned + # for return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + # for debug: : layernorm output = High precision to enable processing of this norm with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not debug ) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + # Kernels not available for norm fusion. + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe + force_hp_fc1_input_gather = ( + fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + # Configure quantizer for norm output if fp8: if fc1_input_quantizer is None: @@ -260,13 +290,14 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: - ln_out = fc1_input_quantizer(ln_out) + if fp8 or debug: + if not force_hp_fc1_input_gather: + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: - if fp8: - if not with_quantized_norm: + if fp8 or debug: + if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -279,18 +310,21 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(fc1_input_quantizer if fp8 else None), + quantizer=(fc1_input_quantizer if fp8 or debug else None), ) else: - if fp8 and not with_quantized_norm: + # NOTE: force_hp_fc1_input_gather is redundant with else, but + # here for clarity. We should not quantize ln_out if bwd needs + # to gather in hp. + if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out # Cast weights to expected dtype - if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) - else: + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + + if fp8 or debug: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. # FP8 cast to workspace buffer @@ -302,6 +336,7 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc2_weight_final = module.get_weight_workspace( @@ -311,11 +346,15 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) # Cast biases to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 if fc1_bias is not None: fc1_bias = cast_if_needed(fc1_bias, bias_dtype) @@ -336,6 +375,7 @@ def forward( # - bias_gelu_fusion - only for full precision. # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer if activation != "gelu": + # blockwise scaled gemms don't support gemm_gelu_fusion in fwd. gemm_gelu_fusion = bias_gelu_fusion = False else: if fp8: @@ -344,13 +384,16 @@ def forward( gemm_gelu_fusion = True if gemm_gelu_fusion and bias_gelu_fusion: gemm_gelu_fusion = False - + if debug: + gemm_gelu_fusion = False fc1_outputs = general_gemm( fc1_weight_final, ln_out_total, get_workspace(), quantization_params=( - fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + fc2_input_quantizer + if gemm_gelu_fusion + else fc1_output_quantizer # fused gelu output is in fp8 ), out_dtype=activation_dtype, bias=( @@ -361,6 +404,7 @@ def forward( ub=ub_obj_lnout, ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): clear_tensor_data(ln_out_total) @@ -374,9 +418,18 @@ def forward( act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) elif gemm_gelu_fusion: act_out, _, fc1_out, _ = fc1_outputs + elif debug: + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise. + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -406,7 +459,7 @@ def forward( get_workspace(), out_dtype=activation_dtype, bias=fc2_bias, - quantization_params=output_quantizer, + quantization_params=fc2_output_quantizer, out=fc2_out, use_split_accumulator=_2X_ACC_FPROP, ub=ub_obj_fc2out, @@ -425,23 +478,9 @@ def forward( clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) else: if cpu_offloading: - if fp8 and fc1_weight_final is not None: - set_offloading_param(fc1_weight_final, "weight_offloading", True) - if fp8 and fc2_weight_final is not None: - set_offloading_param(fc2_weight_final, "weight_offloading", True) - set_offloading_param(ln_weight, "weight_offloading", True) - set_offloading_param(fc1_weight, "weight_offloading", True) - set_offloading_param(fc2_weight, "weight_offloading", True) - set_offloading_param(fc1_bias, "weight_offloading", True) - - set_offloading_param(inputmat, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(rsigma, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(ln_out, "activation_offloading", True) - set_offloading_param(fc1_out, "activation_offloading", True) - set_offloading_param(fc1_out_without_bias, "activation_offloading", True) - set_offloading_param(act_out, "activation_offloading", True) + mark_activation_offload( + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + ) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -458,10 +497,14 @@ def forward( fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) + ctx.fc1_weight_quantizer = fc1_weight_quantizer + ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: clear_tensor_data(ln_out) ln_out = None + elif force_hp_fc1_input_gather: + assert not isinstance(ln_out, QuantizedTensor) if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None @@ -490,11 +533,15 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer - ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather + ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer + ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer + ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer + ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer + ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer + ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad @@ -505,6 +552,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -526,6 +574,7 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag + ctx.debug = debug ctx.requires_dgrad = ( inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad @@ -540,13 +589,20 @@ def forward( if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store + # Row Parallel Linear if ub_overlap_rs: fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) elif set_parallel_mode and tensor_parallel: - fc2_out, _ = allreduce(fc2_out, tp_group) + if symmetric_ar_type is not None: + fc2_out, _ = symmetric_all_reduce( + fc2_out, tp_group, all_reduce_type=symmetric_ar_type + ) + else: + fc2_out, _ = allreduce(fc2_out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1]) @@ -649,18 +705,18 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_fc2_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None: rowwise_usage = True columnwise_usage = True if ctx.ub_overlap_ag and isinstance( - ctx.grad_fc2_output_quantizer, + ctx.fc2_grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer), ): # If data is in FP8 and communication is handled # with Userbuffers, we compute FP8 transposes # manually columnwise_usage = False - ctx.grad_fc2_output_quantizer.set_usage( + ctx.fc2_grad_output_quantizer.set_usage( rowwise=rowwise_usage, columnwise=columnwise_usage, ) @@ -675,7 +731,7 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer ) # Launch tensor-parallel communication for FC1 GEMM input @@ -688,7 +744,7 @@ def backward( and not ctx.ub_bulk_dgrad ): quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -696,11 +752,12 @@ def backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) + gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) else: ln_out_total = ln_out @@ -712,17 +769,26 @@ def backward( ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - # There are 5 possible fusion paths + # There are 6 possible fusion paths # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + not ctx.fp8 + and (ctx.activation == "gelu") + and (not ctx.bias_gelu_fusion) + and (not ctx.debug) ) # FC2 DGRAD; Unconditional + if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor): + ctx.fc2_weight.update_usage( + rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage, + ) gemm_output, *_ = general_gemm( fc2_weight, grad_output, @@ -730,7 +796,9 @@ def backward( layout="NN", grad=True, quantization_params=( - ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ctx.fc1_grad_input_quantizer + if fc2_dgrad_gemm_gelu_fusion or ctx.debug + else None ), # high precision to activation out_dtype=ctx.activation_dtype, gelu=fc2_dgrad_gemm_gelu_fusion, @@ -753,39 +821,65 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) - fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( - act_out, - grad_output, - get_workspace(), + grad_arg = True + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + grad_arg = False + general_gemm_fc2_wgrad = functools.partial( + general_gemm, out_dtype=( origin_fc2_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - quantization_params=None, # wgrad in high precision + workspace=get_workspace(), + quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision layout="NT", - grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + grad=grad_arg, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) - if fc2_bias_grad is None: - fc2_bias_grad = fc2_bias_grad_ - clear_tensor_data(act_out) + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) + fc2_wgrad = None + else: + fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( + act_out, + grad_output, + ) + + if fc2_bias_grad is None: + if ( + ctx.fp8 + and ctx.fp8_recipe.float8_block_scaling() + and fc2_bias is not None + ): + # BGRAD not fused with GEMM for float8 blockwise gemm. + fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) + + fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ + if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute(): + clear_tensor_data(act_out) # bias computation fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.grad_fc1_output_quantizer is not None: - ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.fc1_grad_output_quantizer is not None: + ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) - if ctx.grad_fc1_output_quantizer is not None: - dact = ctx.grad_fc1_output_quantizer(dact) + if ctx.fc1_grad_output_quantizer is not None: + dact = ctx.fc1_grad_output_quantizer(dact) + elif ctx.debug: + dact_func = _act_func(ctx.activation)[1] + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + fc1_bias_grad = dact.sum(dim=0) + dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and ctx.fp8 @@ -795,7 +889,7 @@ def backward( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -808,7 +902,14 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO float8 blockwise current scaling has no bgrad fusion for now + if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.fc1_grad_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.fc1_grad_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -855,12 +956,20 @@ def backward( fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) # FC1 DGRAD: Unconditional + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensor + ): + ctx.fc1_weight.update_usage( + rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage, + ) fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( fc1_weight, dact, get_workspace(), out=fc1_dgrad_bulk, out_dtype=ctx.activation_dtype, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, ub=ub_obj_fc1_dgrad, @@ -904,6 +1013,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. + ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -919,16 +1035,16 @@ def backward( ) # wgrad GEMM - fc1_wgrad_outputs = general_gemm( - ln_out_total, - dact, - get_workspace(), + general_gemm_fc1_wgrad = functools.partial( + general_gemm, out_dtype=( origin_fc1_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), + workspace=get_workspace(), layout="NT", + quantization_params=ctx.fc1_grad_weight_quantizer, grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, @@ -938,13 +1054,23 @@ def backward( extra_output=fc1_dgrad_rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) + fc1_wgrad = None + if fuse_gemm_and_bias_fc1_wgrad: + fc1_bias_grad = None + else: + fc1_wgrad_outputs = general_gemm_fc1_wgrad( + ln_out_total, + dact, + ) - clear_tensor_data(ln_out_total, dact) + clear_tensor_data(ln_out_total, dact) - if fuse_gemm_and_bias_fc1_wgrad: - fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs - else: - fc1_wgrad, *_ = fc1_wgrad_outputs + if fuse_gemm_and_bias_fc1_wgrad: + fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs + else: + fc1_wgrad, *_ = fc1_wgrad_outputs if ctx.ub_bulk_wgrad: if ub_obj_fc1_wgrad.is_fp8_ubuf(): @@ -1061,15 +1187,20 @@ def backward( None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # fuse_wgrad_accumulation - None, # fc1_input_quantizer - None, # fc1_weight_quantizer - None, # fc2_input_quantizer - None, # fc2_weight_quantizer - None, # output_quantizer - None, # grad_fc2_output_quantizer - None, # grad_fc1_output_quantizer - None, # grad_input_quantizer + None, # fc1_input_quantizer, + None, # fc1_weight_quantizer, + None, # fc1_output_quantizer, + None, # fc1_grad_input_quantizer, + None, # fc1_grad_weight_quantizer, + None, # fc1_grad_output_quantizer, + None, # fc2_input_quantizer, + None, # fc2_weight_quantizer, + None, # fc2_output_quantizer, + None, # fc2_grad_input_quantizer, + None, # fc2_grad_weight_quantizer, + None, # fc2_grad_output_quantizer, None, # cpu_offloading None, # tp_group None, # tp_size @@ -1095,6 +1226,8 @@ def backward( None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type + None, # debug ) @@ -1147,6 +1280,8 @@ class LayerNormMLP(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -1189,6 +1324,15 @@ class LayerNormMLP(TransformerEngineBaseModule): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propogation and activation recompute phase. + delay_wgrad_compute : bool, default = `False` + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -1216,10 +1360,13 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, + name: str = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, + delay_wgrad_compute: bool = False, + symmetric_ar_type: Optional[str] = None, ) -> None: super().__init__() @@ -1238,6 +1385,7 @@ def __init__( ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.symmetric_ar_type = symmetric_ar_type # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( @@ -1245,6 +1393,12 @@ def __init__( and self.activation == "gelu" and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers + + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if tp_group is None: self.tp_size = tp_size @@ -1273,6 +1427,13 @@ def __init__( ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad ) + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1405,7 +1566,9 @@ def reset_parameters(self, defer_init=False): @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1428,6 +1591,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1442,17 +1608,35 @@ def forward( fp8_output = True with self.prepare_forward(inp, num_gemms=2) as inp: + + quantizers = ( + self._get_quantizers(fp8_output) + if not debug + else self._get_debug_quantizers(fp8_output) + ) + if debug: + if not any_feature_enabled(quantizers): + quantizers = self._get_quantizers(fp8_output) + debug = False + + if isinstance(self.fc1_weight, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + # Get quantizers ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers # Get weight tensors fc1_weight = self.fc1_weight @@ -1487,15 +1671,20 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, self.fuse_wgrad_accumulation, fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1504,7 +1693,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8, + self.bias_gelu_nvfusion and not self.fp8 and not debug, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1517,10 +1706,12 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_bulk_dgrad, self.ub_bulk_wgrad, - self.gemm_gelu_fusion, + self.gemm_gelu_fusion and not debug, self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, + debug, ) out = fwd_fn(*args) @@ -1542,13 +1733,17 @@ def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = [None] * 8 + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = [None] * 12 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = False # temporary @@ -1556,35 +1751,60 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + rowwise=True, + columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] + fc2_output_quantizer = self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_OUTPUT + ] if torch.is_grad_enabled(): - grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ] - grad_fc2_output_quantizer.internal = True - grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer.internal = True + fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ] - grad_fc1_output_quantizer.internal = True - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] - grad_input_quantizer.internal = True + fc1_grad_output_quantizer.internal = True return ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, ) + def _get_debug_quantizers(self, fp8_output): + from ...debug.pytorch.debug_quantization import DebugQuantizer + + base_quantizers = list(self._get_quantizers(fp8_output)) + assert TEDebugState.debug_enabled + + def make_debug(prefix, offset): + labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return [ + DebugQuantizer( + f"{self.name}.{prefix}", + label, + None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], + self.tp_group, + ) + for i, label in enumerate(labels) + ] + + return tuple(make_debug("fc1", 0) + make_debug("fc2", 6)) + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + layernorm_mlp.""" assert ( @@ -1629,14 +1849,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8FwdTensors.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: - # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale @@ -1644,10 +1864,48 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_INPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: - # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + return + with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): + (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() + if self.use_bias and self.fc1_bias.grad is None: + (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop() + else: + (fc1_wgrad, *_), _ = self.wgrad_store.pop() + fc1_bias_grad = None + if self.use_bias: + if self.fc2_bias.grad is None: + if ( + self.fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling() + and self.apply_bias + and not self.gemm_bias_unfused_add + ): + act_out = tensor_list_fc2[0] + # BGRAD not fused with GEMM for float8 blockwise gemm. + fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) + self.fc2_bias.grad = fc2_bias_grad_.to(self.fc2_bias.dtype) + if self.fc1_bias.grad is None: + self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) + if not self.fuse_wgrad_accumulation: + if self.fc2_weight.grad is None: + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + if self.fc1_weight.grad is None: + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + del fc2_bias_grad_ + del fc2_wgrad + del fc1_wgrad + del fc1_bias_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..7803f4a084 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,36 +7,41 @@ from functools import reduce from operator import mul as multiply_op +import functools import torch import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, _fix_gathered_fp8_transpose +from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, clear_tensor_data, divide, init_method_constant, + requires_grad, + needs_quantized_gemm, non_tn_fp8_gemm_supported, assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, - requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, is_fp8_activation_recompute_enabled, @@ -59,8 +64,10 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled __all__ = ["Linear"] @@ -79,11 +86,13 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -104,6 +113,8 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, + debug: Optional[bool] = False, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -129,6 +140,10 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. + force_hp_input_gather = ( + fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -138,23 +153,31 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" " current scaling" ) - + if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - if not isinstance(inputmat, QuantizedTensor): - columnwise_usage = backward_needs_input and isinstance( - input_quantizer, MXFP8Quantizer + if force_hp_input_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, tp_group, quantizer=input_quantizer + ) + else: + if not isinstance(inputmat, QuantizedTensor): + columnwise_usage = backward_needs_input and isinstance( + input_quantizer, MXFP8Quantizer + ) + # force_hp_input_gather should enforce this + assert not isinstance(input_quantizer, Float8BlockQuantizer) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, ) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - inputmat = input_quantizer(inputmat) - own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=input_quantizer, - ) else: if ( FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() @@ -182,9 +205,9 @@ def forward( nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - if not fp8: - weightmat = cast_if_needed(weight, activation_dtype) - else: + weightmat = weight + + if fp8 or debug: # Configure quantizer if weight_quantizer is not None: columnwise_usage = is_grad_enabled and inp.requires_grad @@ -194,7 +217,6 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -204,11 +226,14 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -263,6 +288,7 @@ def forward( nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer saved_inputmat = None ctx.backward_input_needs_gather = ( @@ -276,6 +302,8 @@ def forward( # can be allgathered. if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if force_hp_input_gather: + assert not isinstance(inputmat, QuantizedTensor) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. @@ -283,11 +311,8 @@ def forward( if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) - if cpu_offloading: - set_offloading_param(weight, "weight_offloading", True) - set_offloading_param(weightmat, "weight_offloading", True) - if saved_inputmat is not None: - set_offloading_param(saved_inputmat, "activation_offloading", True) + if cpu_offloading and saved_inputmat is not None: + mark_activation_offload(saved_inputmat) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights @@ -322,15 +347,18 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer - ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad + ctx.debug = debug ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -354,6 +382,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store # Row Parallel Linear if ub_overlap_rs_fprop: @@ -363,7 +392,10 @@ def forward( if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: - out, _ = allreduce(out, tp_group) + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") out = out.view(-1, *inp_shape[1:-1], out_features) @@ -510,7 +542,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -519,11 +551,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -545,7 +578,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Update quantizer if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD @@ -556,6 +588,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], recipe.fp8_gemm_dgrad.use_split_accumulator ) + if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor): + weight_fp8.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -609,6 +647,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): @@ -633,24 +678,37 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - wgrad, grad_bias_, _, rs_out = general_gemm( - inputmat_total, - grad_output, - get_workspace(), - layout="NT", - grad=True, + general_gemm_wgrad = functools.partial( + general_gemm, out_dtype=( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), + workspace=get_workspace(), + layout="NT", + grad=True, bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad) + else: + wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if ctx.owns_input: + clear_tensor_data(inputmat_total) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: @@ -659,14 +717,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate input tensor - if ctx.owns_input: - clear_tensor_data(inputmat_total) - # Don't return grad bias if not needed if not ctx.use_bias: grad_bias = None @@ -688,18 +738,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -721,11 +768,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -746,6 +795,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type + None, # debug ) @@ -781,6 +832,8 @@ class Linear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -816,7 +869,15 @@ class Linear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. - + delay_wgrad_compute : bool, default = `False` + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -842,6 +903,9 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + delay_wgrad_compute: bool = False, + symmetric_ar_type: Optional[str] = None, + name: Optional[str] = None, ) -> None: super().__init__() @@ -854,6 +918,13 @@ def __init__( self.apply_bias = bias and not return_bias self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.symmetric_ar_type = symmetric_ar_type + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers + + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -919,6 +990,13 @@ def __init__( assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." self.ub_name = ub_name + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1097,6 +1175,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() + if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() else: @@ -1132,13 +1214,28 @@ def forward( else: bias_tensor = None + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization @@ -1159,11 +1256,13 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1184,6 +1283,8 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, + debug, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: @@ -1195,8 +1296,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1214,8 +1316,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 45c78bea87..aa0bb1a52b 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from ...fp8 import FP8GlobalStateManager from ...tensor import QuantizedTensor +from ...tensor.float8_tensor import Float8CurrentScalingQuantizer from ...utils import clear_tensor_data, devices_match from ..op import BasicOperation, OperationContext from .._common import reshape @@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): the first half of the input tensor, while PyTorch applies it to the second half. + Parameters + ---------- + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward + pass. This will typically reduce memory usage but require + extra compute and increase numerical error. This feature is + highly experimental. + """ + def __init__(self, *, cache_quantized_input: bool = False): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + @abc.abstractmethod def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: """Forward implementation @@ -100,9 +113,16 @@ def op_forward( if y.dim() != x.dim(): y = y.reshape(list(x.shape[:-1]) + [-1]) + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + quantizer.set_usage(rowwise=True, columnwise=False) + x = quantizer(x) + # Save state for backward pass ctx.save_for_backward(x.detach()) ctx.fp8_enabled = fp8_enabled + ctx.dtype = dtype ctx.prev_op = prev_op return y @@ -116,10 +136,18 @@ def op_backward( # Saved tensors from forward pass (x,) = ctx.saved_tensors + # Check input tensor + if isinstance(x, QuantizedTensor): + x = x.dequantize(dtype=ctx.dtype) + elif x.dtype != ctx.dtype: + x = x.to(dtype=ctx.dtype) + if not x.is_contiguous(): + x = x.contiguous() + # Check grad output tensor dy = grad_output if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() + dy = dy.dequantize(dtype=ctx.dtype) if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: dy = dy.to(device=x.device, dtype=x.dtype) if not dy.is_contiguous(): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb93eb5e6b..86f17608f4 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,6 +23,7 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext @@ -412,7 +413,6 @@ def _functional_forward( x = None x_async = None with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel - own_quantized_x_local = False if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") @@ -428,7 +428,6 @@ def _functional_forward( else: if not isinstance(x_local, QuantizedTensor): x_local = input_quantizer(x_local) - own_quantized_x_local = True x = x_local else: if isinstance(x_local, QuantizedTensor): @@ -483,6 +482,12 @@ def _functional_forward( "Attempting to generate MXFP8 output tensor, " "but GEMM with MXFP8 output is not supported" ) + if isinstance(output_quantizer, Float8BlockQuantizer): + raise RuntimeError( + "Attempting to generate Float8BlockQuantized output tensor, " + "but GEMM with Float8BlockQuantized output is not supported" + ) + if output_quantizer is not None: output_quantizer.set_usage(rowwise=True, columnwise=False) @@ -521,16 +526,16 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Configure input tensor for backward pass - if own_quantized_x_local: - x_local.update_usage(rowwise_usage=False) - # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save # input tensor as context for backward pass. if x_local is input: x_local = x_local.detach() + # Configure input tensor for backward pass + if with_quantized_compute and isinstance(x_local, QuantizedTensor): + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + return y, x_local, w @staticmethod @@ -679,7 +684,9 @@ def _functional_backward( quantizer=input_quantizer, ) else: - if not isinstance(x_local, QuantizedTensor): + if isinstance(x_local, QuantizedTensor): + x_local.update_usage(columnwise_usage=True) + else: x_local = input_quantizer(x_local) x = x_local else: @@ -706,15 +713,19 @@ def _functional_backward( raise ValueError("Weight tensor is required to compute input grad") w = weight w_is_quantized = isinstance(w, QuantizedTensor) - if with_quantized_compute and not w_is_quantized: - if weight_quantizer is None: - raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(columnwise=True) - w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) + if with_quantized_compute: + if w_is_quantized: + w.update_usage(columnwise_usage=True) + else: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + else: + if w_is_quantized: + w = w.dequantize(dtype=dtype) + elif w.dtype != dtype: + w = w.to(dtype=dtype) # Synchronize tensor-parallel communication _wait_async(dy_async) @@ -867,8 +878,8 @@ def op_forward( # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. - input_quantizer.set_usage(columnwise=weight_requires_grad) - weight_quantizer.set_usage(columnwise=False) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) # Get autocast dtype if needed dtype = None diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..802f4c25e3 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -17,8 +17,10 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, + fp8_autocast, ) from ..tensor import Quantizer @@ -218,6 +220,11 @@ def _reset_quantization_recipe_state( if num_quantizers == 0: continue + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) + # Construct quantization recipe state recipe_state = RecipeState.create( recipe, @@ -259,8 +266,13 @@ def _update_quantization_recipe_state( continue recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( - recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) + or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + or ( + recipe.float8_block_scaling() + and not isinstance(recipe_state, Float8BlockScalingRecipeState) + ) + ) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return @@ -508,7 +520,7 @@ def forward( def get_extra_state(self) -> torch.Tensor: """Serialize extra state - Contains metadata for FP8 casting. + Contains metadata for quantization recipe. """ @@ -540,23 +552,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: dst.copy_(src, non_blocking=True) return dst - # Store FP8 state + # Store quantizer state if needed state = {} for mode in ("forward", "backward"): - # Get state for a given FP8 tensor - if self.num_quantizers(mode) == 0: + # Skip if op has no quantizer state + if self._fp8_metas is None or self._fp8_metas[mode] is None: continue - fp8_meta = self.get_fp8_meta(mode) + + # Quantizer state + fp8_meta = self._fp8_metas[mode] state[mode] = {} + state[mode]["recipe"] = fp8_meta["recipe"] - # Store tensors - if "scaling_fwd" in fp8_meta: - state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) - if "scaling_bwd" in fp8_meta: - state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + # Copy tensors to CPU and store + if state[mode]["recipe"].delayed(): + if mode == "forward": + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items extra = {} @@ -595,37 +611,37 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) dst.copy_(src, non_blocking=True) - # Load FP8 state + # Load quantizer state if needed for mode in ("forward", "backward"): - # Get state for a given FP8 tensor + # Skip if checkpoint has no quantizer state if mode not in state: continue - if self.num_quantizers(mode) == 0: - continue - fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue - # Load extra state + # Get op's quantizer state, initializing if needed + if self._fp8_metas is None or self._fp8_metas[mode] is None: + with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + self._reset_quantization_recipe_state() + fp8_meta = self._fp8_metas[mode] + + # Load extra items + fp8_meta["recipe"] = state[mode]["recipe"] fp8_meta.update(state[mode]["extra_fp8_variables"]) - if "amax_history_fwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) - elif "amax_history_bwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Load tensors - fp8_meta = self.get_fp8_meta(mode) - if "scaling_fwd" in fp8_meta: - fp8_meta_fwd = fp8_meta["scaling_fwd"] - copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) - if "scaling_bwd" in fp8_meta: - fp8_meta_bwd = fp8_meta["scaling_bwd"] - copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + if state[mode]["recipe"].delayed(): + if mode == "forward": + copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) + copy_tensor( + state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history + ) + if mode == "backward": + copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) + copy_tensor( + state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history + ) # Finish CPU-GPU memory transfers torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 070f46e937..18f7e2031a 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -133,10 +133,10 @@ def __init__( # Add constraints to dtypes of states. if master_weights and master_weight_dtype not in [torch.float32, torch.float16]: raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.") - if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]: - raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.") - if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]: - raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.") + if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.") + if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.") # Currently, capturable mode only supports fp32 master weights and optimizer states. # The reason is, if the master weights or optimizer states are not in fp32 dtype, @@ -259,6 +259,10 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): scale (torch.Tensor): A FP32 tensor representing the scaling factor. """ assert unscaled_state.dtype == torch.float32 + if scaled_state.dtype == torch.bfloat16: + scaled_state.copy_(unscaled_state.bfloat16()) + return + dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) @@ -313,8 +317,11 @@ def get_unscaled_state(self, param, state_name): else: assert state[state_name].dtype == torch.float32 unscaled = state[state_name] + elif dtype == torch.bfloat16: + assert state[state_name].dtype == torch.bfloat16 + unscaled = state[state_name].float() else: - raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") + raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.") return unscaled def set_scaled_state(self, param, state_name, unscaled_state): @@ -329,6 +336,7 @@ def set_scaled_state(self, param, state_name, unscaled_state): and 'master_param`. unscaled_state (torch.Tensor): The original high-precision(FP32) state. """ + store_param_remainders = ( self.store_param_remainders and state_name == "master_param" diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index dd2f60deba..d88047a012 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -4,14 +4,16 @@ """MoE Permutaion API""" import warnings -from typing import Tuple +from typing import Optional, Tuple import torch import transformer_engine_torch as tex import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.float8_tensor import Float8Tensor - +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor __all__ = [ "moe_permute", @@ -46,17 +48,7 @@ def forward( assert inp.size(0) == index.size(0), "Permute not possible" # Data type check - fp8 = isinstance(inp, Float8Tensor) - if fp8: - assert ( - inp._quantizer.scale.ndim == 0 - ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute." - dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - dtype = TE_DType[inp.dtype] + dtype = TE_DType[inp.dtype] if index.dtype != torch.int32: warnings.warn( f"The data type of the input `index` of Permute is {index.dtype}! " @@ -80,19 +72,9 @@ def forward( _moe_permute_index_map.max_expanded_token_num, ) - if fp8: - permuted_act = Float8Tensor( - data=permuted_act, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=permuted_act.shape, - dtype=fake_dtype, - ) - ctx.row_id_map = row_id_map ctx.num_tokens = index.size(0) ctx.topK = index.size(1) - ctx.fp8 = fp8 return permuted_act, row_id_map @staticmethod @@ -109,30 +91,12 @@ def backward( if not permuted_act_grad.is_contiguous(): permuted_act_grad = permuted_act_grad.contiguous() - if ctx.fp8: - assert isinstance( - permuted_act_grad, Float8Tensor - ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." - dtype = permuted_act_grad._fp8_dtype - fp8_scale_inv = permuted_act_grad._scale_inv - fake_dtype = permuted_act_grad.dtype - permuted_act_grad = permuted_act_grad._data - else: - dtype = TE_DType[permuted_act_grad.dtype] - + dtype = TE_DType[permuted_act_grad.dtype] act_grad = None if ctx.needs_input_grad[0]: act_grad = tex.moe_permute_bwd( permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK ) - if ctx.fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv * ctx.topK, - shape=act_grad.shape, - dtype=fake_dtype, - ) return act_grad, None, None, None @@ -176,14 +140,7 @@ def forward( assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check - fp8 = isinstance(inp, Float8Tensor) - if fp8: - dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - dtype = TE_DType[inp.dtype] + dtype = TE_DType[inp.dtype] if row_id_map.dtype != torch.int32: warnings.warn( f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " @@ -193,17 +150,7 @@ def forward( unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) - if fp8: - unpermuted_output = Float8Tensor( - data=unpermuted_output, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=unpermuted_output.shape, - dtype=fake_dtype, - ) - ctx.save_for_backward(inp, row_id_map, probs) - ctx.fp8 = fp8 return unpermuted_output @staticmethod @@ -219,17 +166,7 @@ def backward( if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() - if ctx.fp8: - assert isinstance( - unpermuted_act_grad, Float8Tensor - ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." - dtype = unpermuted_act_grad._fp8_dtype - fp8_scale_inv = unpermuted_act_grad._scale_inv - fake_dtype = unpermuted_act_grad.dtype - unpermuted_act_grad = unpermuted_act_grad._data - else: - dtype = TE_DType[unpermuted_act_grad.dtype] - + dtype = TE_DType[unpermuted_act_grad.dtype] inp, row_id_map, probs = ctx.saved_tensors act_grad = None @@ -238,14 +175,6 @@ def backward( act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, inp, dtype, row_id_map, probs ) - if ctx.fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, - dtype=fake_dtype, - ) if not ctx.needs_input_grad[2]: prob_grad = None @@ -282,29 +211,86 @@ def forward( row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) - fp8 = isinstance(inp, Float8Tensor) + fp8 = isinstance(inp, QuantizedTensor) + per_tensor_recipe = isinstance(inp, Float8Tensor) + blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(inp, MXFP8Tensor) + if fp8: fp8_dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv fake_dtype = inp.dtype - inp = inp._data - output, permuted_probs = triton_permutation.permute_with_mask_map( + # blockwise scaling + if blockwise_recipe: + fp8_scale = inp._rowwise_scale_inv.T.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + # mxfp8 scaling + elif mxfp8_recipe: + fp8_scale = inp._rowwise_scale_inv.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + # per-tensor scaling + elif per_tensor_recipe: + # Kernel does not need scale in per-tensor scaling + fp8_scale = None + scale_hidden_dim = None + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + raise ValueError("Unsupported FP8 recipe") + else: + fp8_scale = None + fp8_dtype = None + scale_hidden_dim = None + + output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map( inp, row_id_map, probs, + fp8_scale, num_tokens, num_experts, num_out_tokens, hidden_size, + scale_hidden_dim, ) + if fp8: - output = Float8Tensor( - data=output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=output.shape, - dtype=fake_dtype, - ) + if per_tensor_recipe: + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + output = Float8BlockwiseQTensor( + shape=output.shape, + dtype=fake_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=output.requires_grad, + ) + elif mxfp8_recipe: + output = MXFP8Tensor( + shape=output.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=output.requires_grad, + ) ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts @@ -327,14 +313,9 @@ def backward( probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors - fp8 = isinstance(permuted_act_grad, Float8Tensor) - if fp8: - fp8_dtype = permuted_act_grad._fp8_dtype - fp8_scale_inv = permuted_act_grad._scale_inv - fake_dtype = permuted_act_grad.dtype - permuted_act_grad = permuted_act_grad._data - else: - fp8_dtype = None + assert not isinstance( + permuted_act_grad, QuantizedTensor + ), "The backward of moe_permute does not support FP8." act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, @@ -343,16 +324,7 @@ def backward( ctx.num_tokens, ctx.num_experts, ctx.hidden_size, - fp8_dtype, ) - if fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv * ctx.num_experts, - shape=act_grad.shape, - dtype=fake_dtype, - ) if not ctx.needs_input_grad[3]: probs_grad = None return act_grad, None, None, probs_grad @@ -366,8 +338,8 @@ def forward( ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - merging_probs: torch.Tensor, - restore_shape: torch.Size, + merging_probs: Optional[torch.Tensor], + restore_shape: Optional[torch.Size], ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -387,17 +359,9 @@ def forward( assert inp.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - fp8 = isinstance(inp, Float8Tensor) - if fp8: - fp8_dtype = inp._fp8_dtype - if not with_probs: - fp8_scale_inv = inp._scale_inv * num_experts - else: - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - fp8_dtype = None + assert not isinstance( + inp, QuantizedTensor + ), "The forward of moe_unpermute does not support FP8." unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, @@ -406,16 +370,7 @@ def forward( num_tokens, num_experts, hidden_size, - fp8_dtype=fp8_dtype, ) - if fp8: - unpermuted_output = Float8Tensor( - data=unpermuted_output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=unpermuted_output.shape, - dtype=fake_dtype, - ) if with_probs: ctx.save_for_backward(inp, row_id_map, merging_probs) @@ -442,16 +397,44 @@ def backward(ctx, unpermuted_act_grad): else: (row_id_map,) = ctx.saved_tensors - fp8 = isinstance(unpermuted_act_grad, Float8Tensor) + fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) + per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) + blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor) + if fp8: fp8_dtype = unpermuted_act_grad._fp8_dtype - fp8_scale_inv = unpermuted_act_grad._scale_inv fake_dtype = unpermuted_act_grad.dtype - unpermuted_act_grad = unpermuted_act_grad._data + # per-tensor scaling + if per_tensor_recipe: + # Kernel does not need scale in per-tensor scaling + fp8_scale = None + scale_hidden_dim = None + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + # blockwise scaling + elif blockwise_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + # mxfp8 scaling + elif mxfp8_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + else: + raise ValueError("Unsupported FP8 recipe") else: + scale_hidden_dim = None fp8_dtype = None + fp8_scale = None if ctx.with_probs: + assert ( + not fp8 + ), "The backward of moe_unpermute with merging probs does not support FP8." act_grad, probs_grad = ( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( unpermuted_act_grad, @@ -462,28 +445,55 @@ def backward(ctx, unpermuted_act_grad): ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, - fp8_dtype, ) ) else: - act_grad, _ = triton_permutation.permute_with_mask_map( + act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map( unpermuted_act_grad, row_id_map, None, + fp8_scale, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, + scale_hidden_dim, ) if fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, - dtype=fake_dtype, - ) + if per_tensor_recipe: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + act_grad = Float8BlockwiseQTensor( + shape=act_grad.shape, + dtype=fake_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=act_grad.requires_grad, + ) + elif mxfp8_recipe: + act_grad = MXFP8Tensor( + shape=act_grad.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=act_grad.requires_grad, + ) if not ctx.needs_input_grad[2]: probs_grad = None @@ -568,10 +578,10 @@ def moe_permute_with_probs( def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, - merging_probs: torch.Tensor = None, - restore_shape: torch.Tensor = None, + merging_probs: Optional[torch.Tensor] = None, + restore_shape: Optional[torch.Size] = None, map_type: str = "mask", - probs: torch.Tensor = None, + probs: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -588,7 +598,7 @@ def moe_unpermute( The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. - restore_shape: torch.Tensor + restore_shape: torch.Size, default = None The output shape after the unpermute operation. map_type: str, default = 'mask' Type of the routing map tensor. Should be the same as the value passed to moe_permute. diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 22b86fbcc6..7fa12cc087 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -42,3 +42,27 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) + + +def get_all_tensor_types(): + """ + Get all tensor-like types that can be used in TE. + """ + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ) + + all_tensor_types = [ + torch.Tensor, + torch.nn.Parameter, + Float8Tensor, + Float8TensorBase, + MXFP8Tensor, + MXFP8TensorBase, + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ] + return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 9135237854..7dc380606d 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase: def __new__( cls, *args, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, @@ -71,10 +71,16 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward""" + """ + Prepare the tensor base for saving for backward + + This does not clear the tensors currently, because with PP config + that clears the weight cache between micro-batches. If the rowwise + data is not required for backward, this is a possible memory + pessimization, but is consistent with the other quantized tensor + classes. + """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 2fea2c4f28..2b54e9ed79 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -27,12 +27,14 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - dtype = torch_to_transformer_engine_dtype[dtype] + te_dtype = torch_to_transformer_engine_dtype[dtype] # Make sure FP8 data is in expected format if tensor._data is not None: + if tensor._data.numel() == 0: + return torch.empty_like(tensor._data, dtype=dtype) # Cast from FP8 - return tex.dequantize(tensor, dtype) + return tex.dequantize(tensor, te_dtype) raise NotImplementedError("Casting back from the transpose not implemented yet!") diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 138d1fd29e..695c5ffb8c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -44,7 +44,6 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - assert rowwise self.dtype = fp8_dtype self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales @@ -168,6 +167,11 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.append(shape[i]) return tuple(colwise_shape) + # TODO(kwyss): With FP8 gather support, we need to implement a + # shape/layout/swizzle check to know whether FP8 gather works + # cleanly by stacking data without aliasing tiles and whether + # the scales also stack on the proper dimensions. + def make_empty( self, shape: Iterable[int], @@ -181,13 +185,16 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -489,7 +496,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst.dtype = src.dtype # Check that tensor dimensions match if ( diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 843c7936f2..2694319a0f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -347,6 +347,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + shape: torch.shape, ) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -361,10 +362,11 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling""" return ( MXFP8Tensor._make_in_reduce_ex, ( @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self.shape, ), ) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 019aca9f60..aa433e58bc 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -37,7 +37,8 @@ def prepare_for_saving( def restore_from_saved( tensors: list[Optional[Any]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], -) -> list[Optional[Any]]: + return_saved_tensors: bool = False, +) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]: """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] for tensor in tensors: @@ -47,6 +48,9 @@ def restore_from_saved( else: saved_tensors = tensor.restore_from_saved(saved_tensors) tensor_objects.append(tensor) + + if return_saved_tensors: + return tensor_objects, saved_tensors return tensor_objects @@ -113,7 +117,11 @@ def update_quantized( """Quantize tensor in-place""" def quantize( - self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override ) -> QuantizedTensor: """Quantize tensor""" if out is not None: diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 33c0953d94..8dd04b52d0 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -305,4 +305,11 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo amax=torch.Tensor(), fp8_dtype=model_weight._fp8_dtype, ) + if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor): + # NOTE: The fsdp shard model weight may be a unit8 tensor instead of + # a float8 tensor. We should handle this situation properly. + model_weight_fragment = quantizer.create_tensor_from_data( + model_weight_fragment.view(-1), + model_weight.dtype, + ) quantizer.update_quantized(master_weight, model_weight_fragment) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d829275777..ef7c4c8ab2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -11,6 +11,7 @@ import torch from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention import ( MultiheadAttention, ) @@ -33,6 +34,7 @@ dist_group_type, ) from transformer_engine.pytorch.distributed import get_distributed_world_size +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -277,6 +281,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -336,6 +341,8 @@ def __init__( self.attn_input_format = attn_input_format + self.name = name + attention_args = ( hidden_size, num_attention_heads, @@ -376,6 +383,7 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + name=name + ".self_attention" if name is not None else None, ) if layer_type == "decoder": @@ -389,6 +397,7 @@ def __init__( return_bias=True, normalization=normalization, device=device, + name=name + ".inter_attention" if name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -423,6 +432,7 @@ def __init__( activation=activation, normalization=normalization, device=device, + name=name + ".layernorm_mlp" if name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -679,6 +689,9 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 1c5fd73581..ebf8dd551e 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -10,8 +10,6 @@ import triton import triton.language as tl -from transformer_engine_torch import DType as TE_DType - @triton.jit def _row_id_map_pass_1_kernel( @@ -116,11 +114,14 @@ def _permute_kernel( output_ptr, row_id_map_ptr, probs_ptr, + scale_ptr, permuted_probs_ptr, + permuted_scale_ptr, # sizes num_tokens, num_experts, hidden_size, + scale_hidden_dim, # strides stride_input_token, stride_input_hidden, @@ -128,9 +129,14 @@ def _permute_kernel( stride_output_hidden, stride_probs_token, stride_probs_expert, + stride_scale_token, + stride_scale_hidden, stride_permuted_probs_token, + stride_permuted_scale_token, + stride_permuted_scale_hidden, # metas PERMUTE_PROBS: tl.constexpr, + PERMUTE_SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -140,11 +146,21 @@ def _permute_kernel( mask = cur_off < hidden_size input_off = pid * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) + if PERMUTE_SCALE: + mask_scale = cur_off < scale_hidden_dim + scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden + scale = tl.load(scale_ptr + scale_off, mask=mask_scale) for expert_idx in range(num_experts): dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) if dst_row != -1: output_off = dst_row * stride_output_token + cur_off * stride_output_hidden tl.store(output_ptr + output_off, inp, mask=mask) + if PERMUTE_SCALE: + permuted_scale_off = ( + dst_row * stride_permuted_scale_token + + cur_off * stride_permuted_scale_hidden + ) + tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) if PERMUTE_PROBS: if cur_pos == 0: prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert @@ -173,10 +189,12 @@ def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor, + scale: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, hidden_size: int, + scale_hidden_dim: int, ): # pylint: disable=missing-function-docstring output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") @@ -184,26 +202,42 @@ def permute_with_mask_map( permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") else: permuted_probs = None + + if scale is not None: + permuted_scale = torch.empty( + (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda" + ) + else: + permuted_scale = None + grid = (num_tokens,) _permute_kernel[grid]( inp, output, row_id_map, probs, + scale, permuted_probs, + permuted_scale, num_tokens, num_experts, hidden_size, + scale_hidden_dim, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), probs.stride(0) if probs is not None else None, probs.stride(1) if probs is not None else None, + scale.stride(0) if scale is not None else None, + scale.stride(1) if scale is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None, + permuted_scale.stride(0) if permuted_scale is not None else None, + permuted_scale.stride(1) if permuted_scale is not None else None, PERMUTE_PROBS=probs is not None, + PERMUTE_SCALE=scale is not None, ) - return output, permuted_probs + return output, permuted_scale, permuted_probs @triton.jit @@ -232,18 +266,9 @@ def _unpermute_kernel( # metas WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, - FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - if FP8_DTYPE == "e5m2": - data_type = tl.float8e5 - pytorch_tensor_dtype = tl.uint8 - elif FP8_DTYPE == "e4m3": - data_type = tl.float8e4nv - pytorch_tensor_dtype = tl.uint8 - else: - data_type = input_ptr.dtype.element_ty - assert FP8_DTYPE is None + data_type = input_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) @@ -257,8 +282,6 @@ def _unpermute_kernel( if src_row != -1: input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) - if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True) inp = inp.to(compute_type) if WITH_MERGING_PROBS: merging_prob_off = ( @@ -279,14 +302,7 @@ def _unpermute_kernel( tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) else: tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) - if FP8_DTYPE is not None: - if not WITH_MERGING_PROBS: - # Directly adding these value may cause overflow for fp8, we scale it here. - # The outside fp8_scale_inv is also scaled in the meantime. - accumulator /= num_experts - accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) - else: - accumulator = accumulator.to(data_type) + accumulator = accumulator.to(data_type) output_off = pid * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) current_start += BLOCK_SIZE @@ -315,15 +331,8 @@ def unpermute_with_mask_map( num_tokens: int, num_experts: int, hidden_size: int, - fp8_dtype: TE_DType, ): # pylint: disable=missing-function-docstring - if fp8_dtype == TE_DType.kFloat8E5M2: - fp8_dtype = "e5m2" - elif fp8_dtype == TE_DType.kFloat8E4M3: - fp8_dtype = "e4m3" - else: - fp8_dtype = None output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") if permuted_probs is not None: unpermuted_probs = torch.empty( @@ -353,7 +362,6 @@ def unpermute_with_mask_map( unpermuted_probs.stride(1) if unpermuted_probs is not None else None, WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, - FP8_DTYPE=fp8_dtype, ) return output, unpermuted_probs @@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel( stride_merging_probs_grad_token, stride_merging_probs_grad_expert, # metas - FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - if FP8_DTYPE == "e5m2": - data_type = tl.float8e5 - pytorch_tensor_dtype = tl.uint8 - elif FP8_DTYPE == "e4m3": - data_type = tl.float8e4nv - pytorch_tensor_dtype = tl.uint8 - else: - data_type = fwd_output_grad_ptr.dtype.element_ty - assert FP8_DTYPE is None + data_type = fwd_output_grad_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) @@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel( + current_offset * stride_fwd_output_grad_hidden ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True) inp = inp.to(compute_type) merging_prob_off = ( pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert @@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel( merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) output = inp * merging_prob output = output.to(data_type) - if FP8_DTYPE is not None: - output = output.to(pytorch_tensor_dtype, bitcast=True) output_off = ( dst_row * stride_fwd_input_grad_token + current_offset * stride_fwd_input_grad_hidden @@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel( dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden ) fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - if FP8_DTYPE is not None: - fwd_input = fwd_input.to(data_type, bitcast=True) prob_grad_accum += fwd_input.to(compute_type) * inp current_start += BLOCK_SIZE probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) @@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_experts: int, num_out_tokens: int, hidden_size: int, - fp8_dtype: TE_DType, ): # pylint: disable=missing-function-docstring - if fp8_dtype == TE_DType.kFloat8E5M2: - fp8_dtype = "e5m2" - elif fp8_dtype == TE_DType.kFloat8E4M3: - fp8_dtype = "e4m3" - else: - fp8_dtype = None act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) @@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs.stride(1), merging_probs_grad.stride(0), merging_probs_grad.stride(1), - fp8_dtype, ) return act_grad, merging_probs_grad diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 603c1d5de4..8450460c46 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,6 +11,7 @@ import torch import transformer_engine.pytorch.cpp_extensions as ext +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from .tensor.quantized_tensor import QuantizedTensor @@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple): return ((value + multiple - 1) // multiple) * multiple +def needs_quantized_gemm(obj, rowwise=True): + """Used to check if obj will need quantized gemm or normal gemm.""" + if isinstance(obj, DebugQuantizedTensor): + return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck + torch.Tensor, + torch.nn.Parameter, + ] + return type(obj) not in [ + torch.Tensor, + torch.nn.Parameter, + ] # pylint: disable=unidiomatic-typecheck + + @functools.lru_cache(maxsize=None) def _nvtx_enabled() -> bool: """Check if NVTX range profiling is enabled"""