From d69585a8057e93851646f7c92b179492b7b9337d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 6 Mar 2025 16:21:28 -0800 Subject: [PATCH 01/53] Recipe setup for Linear modules. Signed-off-by: Keith Wyss --- tests/pytorch/distributed/run_numerics.py | 5 +- tests/pytorch/distributed/test_numerics.py | 4 +- .../test_float8_blockwise_gemm_exact.py | 9 +- .../test_float8_blockwise_scaling_exact.py | 157 +++++++++++++++++- .../test_float8_current_scaling_exact.py | 16 +- tests/pytorch/test_numerics.py | 12 ++ .../common/gemm/cublaslt_gemm.cu | 15 +- transformer_engine/common/recipe/__init__.py | 101 +++++++++++ .../pytorch/cpp_extensions/gemm.py | 6 +- transformer_engine/pytorch/distributed.py | 3 + transformer_engine/pytorch/fp8.py | 118 ++++++++++++- transformer_engine/pytorch/module/base.py | 19 ++- .../pytorch/module/grouped_linear.py | 2 + .../pytorch/module/layernorm_linear.py | 42 +++++ .../pytorch/module/layernorm_mlp.py | 68 +++++--- transformer_engine/pytorch/module/linear.py | 4 +- 16 files changed, 532 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ae5993eb1e..b423bce53d 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -19,6 +19,7 @@ MXFP8BlockScaling, DelayedScaling, Float8CurrentScaling, + Float8BlockScaling, Format, Recipe, ) @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe: return MXFP8BlockScaling() if QUANTIZATION == "fp8_cs": return Float8CurrentScaling() + if QUANTIZATION == "fp8_block_scaling": + return Float8BlockScaling() return te.fp8.get_default_fp8_recipe() @@ -85,7 +88,7 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"): global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE SEQ_LEN = 32 BATCH_SIZE = 32 diff --git a/tests/pytorch/distributed/test_numerics.py 
b/tests/pytorch/distributed/test_numerics.py index b4e2b680b3..1333349661 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -48,7 +48,7 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -56,4 +56,6 @@ def test_distributed(quantization): pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) _run_test(quantization) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9ddb4b9989..364851a7f7 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,21 +8,18 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ) + supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + return supported def cublas_gemm_fp8_blockwise_case( diff --git 
a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e638fe8c5b..7e1d28ce32 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -4,11 +4,14 @@ from typing import Tuple import math +import os +import pathlib import pytest import torch import transformer_engine as te import transformer_engine_torch as tex -from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -18,10 +21,28 @@ BlockwiseQuantizerReference, QuantizeResult, ) +from tests.pytorch.test_float8_current_scaling_exact import ( + TestFP8RecipeLinearBase, + TestFP8RecipeLayerNormLinearBase +) + +# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + +class GetRecipes: + + @staticmethod + def none(): + return None -# TODO replace with call to fp8.py when recipe added. -recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 -reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." 
+ @staticmethod + def fp8_blockwise(): + # return default configs + return Float8BlockScaling() def initialize_for_many_scales( @@ -65,7 +86,6 @@ def initialize_for_many_scales( ] return result - @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", @@ -292,3 +312,130 @@ def test_quantization_block_tiling_extrema_versus_reference( atol=0.0, rtol=0.0, ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class 
TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 9741b1258c..1d7ea56987 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,8 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - return torch.mean(torch.abs((a - b) / b)) + error = torch.where(b == 0, 0.0, torch.abs((a - b) / b)) + return torch.mean(error) @staticmethod def _load_golden_tensor_values(a, 
b): @@ -97,9 +98,16 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = ( + "ScalingType.PER_TENSOR" + ) + elif recipe.fp8blockwise(): + scaling_type = ( + "ScalingType.BLOCKWISE" + ) + else: + scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..a904d08c27 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,6 +50,7 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -104,6 +105,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] @@ -563,6 +565,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.fp8blockwise() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -675,6 +679,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.fp8blockwise() and not fp8_block_scaling_available: + 
pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1528,6 +1534,8 @@ def test_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.fp8blockwise(): + pytest.skip("Grouped linear for FP8 blockwise unsupported.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1723,6 +1731,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.fp8blockwise(): + pytest.skip("Float8 block scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1933,6 +1943,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.fp8blockwise() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 483a1380ef..aa6cd11afc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -116,7 +116,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; - NVTE_CHECK(ret.lda == ret.ldb, "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + NVTE_CHECK(ret.lda == ret.ldb, + "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); } else { // In these scaling modes, the physical layout of @@ -131,7 +132,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT 
layout not allowed."); - } + } } if (is_tensor_scaling(A.scaling_mode)) { @@ -294,10 +295,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, - param.transA == CUBLAS_OP_N ? k : m, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, - param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b676bf6ab0..6f428c540b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -81,6 +81,9 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + def fp8blockwise(self): + """Whether the given recipe is float8 blockwise scaling.""" + return isinstance(self, Float8BlockScaling) @dataclass() class DelayedScaling(Recipe): @@ -287,3 +290,101 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + + + +@dataclass() +class Float8BlockScaling(Recipe): + """ + Use block-wise scaling for FP8 tensors. + + In this strategy, tensors are scaled in blockwise fashion. Values within + each block share a common scaling factor. The block dimensionality + can be configured. The scaling factors are float32. 
+ + Since the scaling happens in a particular direction (either rowwise + or columnwise), the quantized tensor and its transpose are not numerically + equivalent. Due to this, when Transformer Engine needs both the FP8 tensor + and its transpose (e.g. to calculate both forward and backward pass), + during the quantization both versions are computed from the high precision + input to avoid double quantization errors. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of gradient tensor dY + x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for x. + w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for w. + grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for grad. + fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_dpa: bool, default = `False` + Whether to enable FP8 dot product attention (DPA). When the model is placed in an + `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + inputs from higher precision to FP8, performs attention in FP8, and casts tensors + back to higher precision as outputs. 
FP8 DPA currently is only supported in the + `FusedAttention` backend. + fp8_mha: bool, default = `False` + Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting + operations mentioned above at the DPA boundaries. Currently only standard MHA modules + i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + """ + + + fp8_format: Format = Format.HYBRID + fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + x_block_scaling_dim: int = 1 + w_block_scaling_dim: int = 2 + grad_block_scaling_dim: int = 1 + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" + assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" + assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" + assert not (self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2), "2D by 2D block gemm not supported." + assert not (self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2), "2D by 2D block gemm not supported." + assert not (self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2), "2D by 2D block gemm not supported." + assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." 
+ assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." + assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"x_block_scaling_dim={self.x_block_scaling_dim}, " + f"w_block_scaling_dim={self.w_block_scaling_dim}, " + f"grad_block_scaling_dim={self.grad_block_scaling_dim}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..f1030fc127 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,7 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = [ "general_gemm", "general_grouped_gemm", @@ -112,6 +112,10 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + # There is not use_split_accumulator == False + # implementation for Float8BlockwiseQTensorBase GEMM + use_split_accumulator = True args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..3f2f93181c 100644 --- a/transformer_engine/pytorch/distributed.py +++ 
b/transformer_engine/pytorch/distributed.py @@ -382,6 +382,9 @@ def backward( ), fp8_autocast( enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe ): + # FIXME(kwyss): Activation recomputation should + # reuse the quantization settings that were present + # at the time of the original forward pass. outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) # Set the states back to what it was at the start of this function. diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 38f829c079..76549d165e 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -19,6 +19,7 @@ Format, MXFP8BlockScaling, Float8CurrentScaling, + Float8BlockScaling, ) from .constants import dist_group_type @@ -48,6 +49,13 @@ def check_mxfp8_support() -> Tuple[bool, str]: return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if (get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.8): + return True, "" + return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" @@ -109,6 +117,8 @@ class FP8GlobalStateManager: skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None @classmethod def reset(cls) -> None: @@ -134,6 +144,8 @@ def reset(cls) -> None: cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -161,6 +173,13 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + if cls.fp8_block_scaling_available is None: + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = check_fp8_block_scaling_support() + return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -434,6 +453,10 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block + @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -786,8 +809,10 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState + elif recipe.fp8blockwise(): + cls = Float8BlockScalingRecipeState else: - raise 
ValueError("{recipe.__class__.__name__} is not supported") + raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, @@ -928,3 +953,94 @@ def make_quantizers(self) -> list: from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. 
+ assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] * (self.num_quantizers // 3) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] * (self.num_quantizers // 2) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..d1bab93385 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, 
+ Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,8 +35,10 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -499,6 +502,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return + if recipe.fp8blockwise() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -841,7 +848,7 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase) ): grad_output = quantizer(grad_output) @@ -859,11 +866,15 @@ def grad_output_preprocess( # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase)): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(quantizer, Float8BlockQuantizer): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
+ grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase)): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..ad869bf0d4 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -91,6 +91,8 @@ def forward( # TODO Support Float8 Current Scaling # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") + if fp8 and FP8GlobalStateManager.get_fp8_recipe().fp8blockwise(): + raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..745b3caf17 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -56,9 +56,12 @@ restore_from_saved, ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + +from transformer_engine.common.recipe import Recipe from ..cpp_extensions import ( general_gemm, ) @@ -170,10 +173,49 @@ def forward( columnwise_usage and with_input_all_gather and not isinstance(input_quantizer, MXFP8Quantizer) + and not 
isinstance(input_quantizer, Float8BlockQuantizer) ): columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # Configure quantizer for normalization output + with_quantized_norm = fp8 and not return_layernorm_output + # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer + # so we need to set with_quantized_norm to False + if isinstance(input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False + if isinstance(input_quantizer, Float8BlockQuantizer): + # Quantizer has not been fused with norm yet. + with_quantized_norm = False + if with_quantized_norm: + if with_input_all_gather: + if isinstance(input_quantizer, MXFP8Quantizer): + with_quantized_norm = False + + # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` + if ( + fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + input_quantizer.set_usage(rowwise=True, columnwise=False) + + ub_obj_fprop = None + ln_out = None + # For DelayScaling, output of normalization will be in fp8. + # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. 
+ if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + elif with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") + else: + ln_out = torch.empty_like( + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" + ) + # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..354e6ecd5a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -52,7 +52,7 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) - +from transformer_engine.common.recipe import Recipe from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -62,6 +62,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..tensor.quantized_tensor import ( @@ -104,17 +105,20 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } # no activation fusion written yet - # Per-tensor current scaling: [] - return { - "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), - "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, None), - "qgeglu": (tex.qgeglu, 
tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, None), - } + # Per-tensor current scaling or fp8 blockwise scaling: [] + if recipe.float8_current_scaling() or recipe.fp8blockwise(): + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), + } + else: + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -122,7 +126,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] - # Per-tensor current scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling: [] funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -214,6 +218,9 @@ def forward( with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered ) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + # Kernels not available for norm fusion. + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered @@ -336,6 +343,7 @@ def forward( # - bias_gelu_fusion - only for full precision. # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer if activation != "gelu": + # blockwise scaled gemms don't support gemm_gelu_fusion in fwd. 
gemm_gelu_fusion = bias_gelu_fusion = False else: if fp8: @@ -376,7 +384,12 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs else: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8 and FP8GlobalStateManager.get_fp8_recipe().fp8blockwise(): + # tex.quantize does not support GELU fusion for blockwise. + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -505,6 +518,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -712,14 +726,16 @@ def backward( ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - # There are 5 possible fusion paths + # There are 6 possible fusion paths # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + not ctx.fp8 and (ctx.activation == "gelu") + and (not ctx.bias_gelu_fusion) ) # FC2 DGRAD; Unconditional @@ -753,6 +769,9 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + grad_arg = True + if ctx.fp8 and ctx.fp8_recipe.fp8blockwise(): + grad_arg = False fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( act_out, 
grad_output, @@ -764,14 +783,18 @@ def backward( ), quantization_params=None, # wgrad in high precision layout="NT", - grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + grad=grad_arg, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: + if ctx.fp8 and ctx.fp8_recipe.fp8blockwise() and fc2_bias is not None: + # BGRAD not fused with GEMM for float8 blockwise gemm. + fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ clear_tensor_data(act_out) # bias computation @@ -808,7 +831,14 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO float8 blockwise current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..185a91ab83 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -61,6 +61,8 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from transformer_engine.common.recipe import Recipe + __all__ = ["Linear"] @@ -322,8 +324,8 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype - ctx.fp8_recipe = 
FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer From 754e0bdddd08434fff0367245d9ee5f4ac04ecf4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 13:18:15 -0700 Subject: [PATCH 02/53] Use 12.9 feature test. Signed-off-by: Keith Wyss --- .../test_float8_blockwise_scaling_exact.py | 4 +++- .../test_float8_current_scaling_exact.py | 8 ++----- tests/pytorch/test_numerics.py | 4 +++- .../common/gemm/cublaslt_gemm.cu | 15 +++++-------- transformer_engine/common/recipe/__init__.py | 15 ++++++++----- .../pytorch/cpp_extensions/gemm.py | 1 + transformer_engine/pytorch/fp8.py | 17 +++++++++----- transformer_engine/pytorch/module/base.py | 22 ++++++++++++++----- .../pytorch/module/layernorm_mlp.py | 3 +-- 9 files changed, 53 insertions(+), 36 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 7e1d28ce32..e4a26d128f 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -23,7 +23,7 @@ ) from tests.pytorch.test_float8_current_scaling_exact import ( TestFP8RecipeLinearBase, - TestFP8RecipeLayerNormLinearBase + TestFP8RecipeLayerNormLinearBase, ) # read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory @@ -33,6 +33,7 @@ TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + class GetRecipes: @staticmethod @@ -86,6 +87,7 @@ def initialize_for_many_scales( ] return result + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", diff --git 
a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 1d7ea56987..75ee2939c2 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -99,13 +99,9 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias # Expected tensor names based on the naming template if recipe.float8_current_scaling(): - scaling_type = ( - "ScalingType.PER_TENSOR" - ) + scaling_type = "ScalingType.PER_TENSOR" elif recipe.fp8blockwise(): - scaling_type = ( - "ScalingType.BLOCKWISE" - ) + scaling_type = "ScalingType.BLOCKWISE" else: scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a904d08c27..54df59fe65 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,7 +50,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) sm_80plus = get_device_compute_capability() >= (8, 0) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index aa6cd11afc..483a1380ef 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -116,8 +116,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; - NVTE_CHECK(ret.lda == ret.ldb, - "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + NVTE_CHECK(ret.lda == ret.ldb, "Minor 
dimension must be equal for NVTE_BLOCK_SCALING Gemm."); } else { // In these scaling modes, the physical layout of @@ -132,7 +131,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT layout not allowed."); - } + } } if (is_tensor_scaling(A.scaling_mode)) { @@ -295,12 +294,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, - param.transA == CUBLAS_OP_N ? k : m, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, - param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 6f428c540b..a1fb9527cb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -85,6 +85,7 @@ def fp8blockwise(self): """Whether the given recipe is float8 blockwise scaling.""" return isinstance(self, Float8BlockScaling) + @dataclass() class DelayedScaling(Recipe): """ @@ -292,7 +293,6 @@ def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," - @dataclass() class Float8BlockScaling(Recipe): """ @@ -348,7 +348,6 @@ class Float8BlockScaling(Recipe): `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. 
""" - fp8_format: Format = Format.HYBRID fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) @@ -366,9 +365,15 @@ def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" - assert not (self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2), "2D by 2D block gemm not supported." - assert not (self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2), "2D by 2D block gemm not supported." - assert not (self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." 
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index f1030fc127..79d6391e79 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase + __all__ = [ "general_gemm", "general_grouped_gemm", diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 76549d165e..78c8ad6d2d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -49,14 +49,18 @@ def check_mxfp8_support() -> Tuple[bool, str]: return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." + def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" - if (get_device_compute_capability() >= (9, 0) + if ( + get_device_compute_capability() >= (9, 0) and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.8): + and float(torch.version.cuda) >= 12.9 + ): return True, "" return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
+ def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -177,7 +181,9 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: """Return if Float8 block scaling support is available.""" if cls.fp8_block_scaling_available is None: - cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = check_fp8_block_scaling_support() + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( + check_fp8_block_scaling_support() + ) return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling @staticmethod @@ -457,7 +463,6 @@ def fp8_autocast_enter( fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() assert fp8_block_available, reason_for_no_fp8_block - @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" @@ -996,7 +1001,7 @@ def make_quantizers(self) -> list: # The index convention (coming from base.py set_meta_tensor) # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. 
- assert self.num_quantizers % 3 == 0 # x, w, output per gemm + assert self.num_quantizers % 3 == 0 # x, w, output per gemm return [ Float8BlockQuantizer( fp8_dtype=self.qx_dtype, @@ -1025,7 +1030,7 @@ def make_quantizers(self) -> list: ] * (self.num_quantizers // 3) assert self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm return [ Float8BlockQuantizer( fp8_dtype=self.qgrad_dtype, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d1bab93385..f136df864b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -502,9 +502,7 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return - if recipe.fp8blockwise() and isinstance( - recipe_state, Float8BlockScalingRecipeState - ): + if recipe.fp8blockwise() and isinstance(recipe_state, Float8BlockScalingRecipeState): return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and @@ -848,7 +846,13 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase) + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ): grad_output = quantizer(grad_output) @@ -866,7 +870,10 @@ def grad_output_preprocess( # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase)): + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: if isinstance(quantizer, Float8BlockQuantizer): @@ -874,7 +881,10 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase)): + if not isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 354e6ecd5a..4b065add64 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -734,8 +734,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") - and (not ctx.bias_gelu_fusion) + not ctx.fp8 and 
(ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) ) # FC2 DGRAD; Unconditional From 14839969f2a8affad461624587b740991c8fae40 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 20:01:11 -0700 Subject: [PATCH 03/53] Run against tensor dumps from internal library. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_scaling_exact.py | 2 +- tests/pytorch/test_float8_current_scaling_exact.py | 12 ++++++++---- transformer_engine/common/recipe/__init__.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e4a26d128f..44af51e78a 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -435,7 +435,7 @@ def test_fp8_current_scaling_with_layernorm_linear_module( dtype=dtype, y_error=0.5, ln_out_error=0.5, - dgrad_error=1, + dgrad_error=1.6, wgrad_error=1, bgrad_error=0.5, recipe1_golden_tensors=None, diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 75ee2939c2..d8c19143d7 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -101,7 +101,7 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias if recipe.float8_current_scaling(): scaling_type = "ScalingType.PER_TENSOR" elif recipe.fp8blockwise(): - scaling_type = "ScalingType.BLOCKWISE" + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" else: scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed @@ -441,9 +441,13 @@ def _check_golden_tensor_dumps( fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if 
recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.fp8blockwise(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" + current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index a1fb9527cb..b17cd2fe72 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -348,7 +348,7 @@ class Float8BlockScaling(Recipe): `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. """ - fp8_format: Format = Format.HYBRID + fp8_format: Format = Format.E4M3 fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) From f0dadc561dc374556a21d7419c4edd2dc1a15210 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Mar 2025 17:00:15 -0700 Subject: [PATCH 04/53] Update FIXME to TODO with linked issue. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 3f2f93181c..0c0628823b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -382,7 +382,7 @@ def backward( ), fp8_autocast( enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe ): - # FIXME(kwyss): Activation recomputation should + # TODO(github issues/1568): Activation recomputation should # reuse the quantization settings that were present # at the time of the original forward pass. 
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) From e4f2c28edc599589e0c14a879b164e4de9028dff Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 13 Mar 2025 11:16:20 -0700 Subject: [PATCH 05/53] Update full recompute feature to save recipe. The recompute context uses the same recipe and fp8 settings as the original fwd pass. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 0c0628823b..e245b788b4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -382,9 +382,6 @@ def backward( ), fp8_autocast( enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe ): - # TODO(github issues/1568): Activation recomputation should - # reuse the quantization settings that were present - # at the time of the original forward pass. outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) # Set the states back to what it was at the start of this function. From ea8f53e2da586e778d46b4708ca3ce69df8bfb7a Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 13 Mar 2025 11:37:05 -0700 Subject: [PATCH 06/53] MR Feedback. Avoid reusing quantizer objects. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/fp8.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 78c8ad6d2d..ec59ab8b52 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -1002,7 +1003,7 @@ def make_quantizers(self) -> list: # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. 
assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return [ + return list(itertools.chain.from_iterable([[ Float8BlockQuantizer( fp8_dtype=self.qx_dtype, rowwise=True, @@ -1027,11 +1028,11 @@ def make_quantizers(self) -> list: force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, block_scaling_dim=self.recipe.x_block_scaling_dim, ), - ] * (self.num_quantizers // 3) + ] for _ in range(self.num_quantizers // 3)])) assert self.mode == "backward", f"Unexpected mode {self.mode}" assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return [ + return list(itertools.chain.from_iterable([[ Float8BlockQuantizer( fp8_dtype=self.qgrad_dtype, rowwise=True, @@ -1048,4 +1049,4 @@ def make_quantizers(self) -> list: force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, block_scaling_dim=self.recipe.grad_block_scaling_dim, ), - ] * (self.num_quantizers // 2) + ] for _ in range(self.num_quantizers // 2)])) From dfcb3df238feab047642c43dbd7ad8489b2d5905 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 13 Mar 2025 12:16:18 -0700 Subject: [PATCH 07/53] Update logic in module. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4b065add64..6c0b7fe449 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1585,7 +1585,7 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + rowwise=True, columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)) ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True From b938c3e4e06aa12a69c04b0faf8ccf1d663635b7 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 14 Mar 2025 11:41:07 -0700 Subject: [PATCH 08/53] Format py. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/fp8.py | 102 ++++++++++-------- .../pytorch/module/layernorm_mlp.py | 3 +- 2 files changed, 60 insertions(+), 45 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index ec59ab8b52..681bddd774 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1003,50 +1003,64 @@ def make_quantizers(self) -> list: # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. 
assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list(itertools.chain.from_iterable([[ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] for _ in range(self.num_quantizers // 3)])) + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) assert self.mode == "backward", f"Unexpected mode {self.mode}" assert self.num_quantizers % 2 == 0 # 
grad_output and grad_input per gemm - return list(itertools.chain.from_iterable([[ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] for _ in range(self.num_quantizers // 2)])) + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c0b7fe449..9b59fa9ab9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1585,7 +1585,8 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)) + rowwise=True, + columnwise=isinstance(fc2_input_quantizer, 
(MXFP8Quantizer, Float8BlockQuantizer)), ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True From c8f6322ad942649805bc3e789fa78f8affc5475c Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 31 Mar 2025 10:29:41 -0700 Subject: [PATCH 09/53] Update for PP bug. Signed-off-by: Keith Wyss --- .../tensor/_internal/float8_blockwise_tensor_base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 9135237854..c33535f13c 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -71,10 +71,16 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward""" + """ + Prepare the tensor base for saving for backward + + This does not clear the tensors currently, because with PP config + that clears the weight cache between micro-batches. If the rowwise + data is not required for backward, this is a possible memory + pessimization, but is consistent with the other quantized tensor + classes. + """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( From 4c5f51f2c56a715ffedafb61436f1db244c82119 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 31 Mar 2025 11:20:51 -0700 Subject: [PATCH 10/53] Update test numerics. 
Signed-off-by: Keith Wyss --- tests/pytorch/distributed/test_numerics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1333349661..632f50e90a 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -28,6 +28,9 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) From cad09a96f038da05e827dfa8141815cf8b6a3c26 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 11:56:36 -0700 Subject: [PATCH 11/53] Update force_power_of_2 scales in the recipe. Signed-off-by: Keith Wyss --- transformer_engine/common/recipe/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b17cd2fe72..89bfc69137 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -349,9 +349,9 @@ class Float8BlockScaling(Recipe): """ fp8_format: Format = Format.E4M3 - fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=True, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=True, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=True, amax_epsilon=0.0) x_block_scaling_dim: int = 1 w_block_scaling_dim: int = 2 grad_block_scaling_dim: int = 1 From a9e31782ee5c85230c7d3ffcfab242450f5660b2 Mon Sep 17 00:00:00 2001 From: Keith 
Wyss Date: Tue, 1 Apr 2025 14:09:05 -0700 Subject: [PATCH 12/53] Update usage method to satisfy upstream changes. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/module/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 185a91ab83..bd2811cdb3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -146,7 +146,7 @@ def forward( if with_input_all_gather_nccl: if not isinstance(inputmat, QuantizedTensor): columnwise_usage = backward_needs_input and isinstance( - input_quantizer, MXFP8Quantizer + input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer) ) input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) inputmat = input_quantizer(inputmat) From ac65cee10696f9ad5942d963e64851733eb5c1fa Mon Sep 17 00:00:00 2001 From: zhongboz Date: Wed, 2 Apr 2025 19:09:01 -0700 Subject: [PATCH 13/53] fix subchannel recipe in distributed test with bf16 gather Signed-off-by: zhongboz --- transformer_engine/pytorch/distributed.py | 164 +++++++++++++++++- .../pytorch/module/layernorm_linear.py | 7 +- .../pytorch/module/layernorm_mlp.py | 7 +- transformer_engine/pytorch/module/linear.py | 35 ++-- .../pytorch/tensor/float8_blockwise_tensor.py | 12 ++ 5 files changed, 209 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..0dc4777838 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -24,10 +24,11 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.float8_blockwise_tensor import Float8BlockQuantizer, Float8BlockwiseQTensor from .tensor.quantized_tensor import QuantizedTensor, Quantizer 
from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -936,6 +937,157 @@ def _all_gather_fp8( return out, handle +def _all_gather_fp8_blockwise( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: Optional[Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[Float8BlockwiseQTensorBase, Optional[torch.distributed.Work]]: + """All-gather FP8 tensor along first dimension for blockwise quantization.""" + + # Input tensor attributes + in_shape: Iterable[int] + device: torch.device + dtype: torch.dtype + if isinstance(inp, torch.Tensor): + in_shape = inp.size() + device = inp.device + dtype = inp.dtype + elif isinstance(inp, Float8BlockwiseQTensorBase): + if inp._rowwise_data is not None: + in_shape = inp._rowwise_data.device.size() + device = inp._rowwise_data.device + dtype = inp._rowwise_data.dtype + elif inp._columnwise_data is not None: + in_shape = inp._columnwise_data.device.size() + device = inp._columnwise_data.device + dtype = inp._columnwise_data.dtype + else: + raise ValueError("Got Float8BlockwiseQTensor input tensor without any data") + dtype = torch.bfloat16 + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " + f"found {inp.__class__.__name__})" + ) + + world_size = get_distributed_world_size(process_group) + + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + else: + assert quantizer.block_scaling_dim == 1 and quantizer.block_len == 128, "Only 1D blockwise quantization is supported for allgather" + + # Output tensor dims + if out_shape is None: + out_shape = 
list(inp.size()) + out_shape[0] *= world_size + + # Since gather is for inp tensor or grad tensor, the block scaling dimension should be 1D 1x128 + # This means that we need to do quantization in both directions, and then gather twice + # This also requires a bidirectional quantization without transpose kernel + # Scaling factors should be in compact format without sizzling fused for the purpose of communication + # Doing BF16 gather for now as baseline because it's simpler + if ( + not isinstance(inp, Float8BlockwiseQTensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + raise Exception("Not implemented for doing fp8 blockwise allgather") + + # TODO: finish the scheleton below for doing fp8 blockwise allgather + + # # Cast input tensor to Float8BlockwiseQTensor with required data + # if not isinstance(inp, Float8BlockwiseQTensorBase): + # inp = quantizer(inp) + # elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + # quantizer.columnwise_usage and inp._columnwise_data is None + # ): + # warnings.warn( + # "Input and quantizer do not have matching usages. " + # "Dequantizing and requantizing to Float8BlockwiseQTensor." 
+ # ) + # inp = quantizer(inp.dequantize()) + + # # Construct Float8BlockwiseQTensor output tensor + # out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # # Coalesce NCCL collectives + # with torch.distributed._coalescing_manager( + # group=process_group, + # device=device, + # async_ops=async_op, + # ) as coalescing_manager: + + # # Gather Float8BlockwiseQTensor data for row-wise usage + # if quantizer.rowwise_usage: + + # # Remove padding from Float8BlockwiseQTensor scale-inverses + # # TODO: figure out scale inv tensor shape here + + # # Launch all-gathers + # torch.distributed.all_gather_into_tensor( + # out_scale_inv, + # in_scale_inv, + # group=process_group, + # ) + # torch.distributed.all_gather_into_tensor( + # out._rowwise_data, + # inp._rowwise_data, + # group=process_group, + # ) + + # # Gather Float8BlockwiseQTensor data for column-wise usage + # if quantizer.columnwise_usage: + + # # Remove padding from Float8BlockwiseQTensor scale-inverses + # # TODO: figure out columnwise scale inv tensor shape here + + # # Launch all-gathers + # torch.distributed.all_gather_into_tensor( + # out_scale_inv, + # in_scale_inv, + # group=process_group, + # ) + # torch.distributed.all_gather_into_tensor( + # out._columnwise_data, + # inp._columnwise_data, + # group=process_group, + # ) + + # handle = coalescing_manager if async_op else None + + # # Unlink MXFP8, this fp8 blockwise tensor should also work with Hopper + # # This means that we need to transpose the gathered columnwise data + # # Example usage is grad_output tensor, ie. 
dY in linear backward + # # We want to gather two FP8 tensors (rowwise and columnwise) along dim0 + # # and then transpose the columnwise data to match the rowwise data + # # Make sure FP8 transpose is populated if needed + # needs_transpose = ( + # quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported() + # ) + # if needs_transpose: + # if handle is not None: + # handle.wait() + # handle = None + # # TODO: this transpose will transpose both data and scale inverses + # # out._create_transpose() + + # return out, handle + def _all_gather_mxfp8( inp: torch.Tensor, @@ -1099,6 +1251,16 @@ def gather_along_first_dim( quantizer=quantizer, out_shape=out_shape, ) + + # FP8 block scaling case, block length = 128 + if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + return _all_gather_fp8_blockwise( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) # MXFP8 case if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 745b3caf17..c626a263c7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -241,6 +241,8 @@ def forward( ln_out_total = None ub_obj_fprop = None if with_input_all_gather: + # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered # norm output will be returned @@ -252,7 +254,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not (with_quantized_norm or force_high_precision_gather): ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, 
columnwise=False) if ub_overlap_ag_fprop: @@ -642,6 +644,9 @@ def backward( if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) + elif isinstance(quantizer, Float8BlockQuantizer): + # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9b59fa9ab9..f856b2844f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -262,6 +262,8 @@ def forward( ln_out_total = None ub_obj_lnout = None if sequence_parallel: + # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + force_high_precision_gather = isinstance(fc1_input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered # norm output will be returned @@ -273,7 +275,7 @@ def forward( ln_out_total = fc1_input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not (with_quantized_norm or force_high_precision_gather): ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -707,6 +709,9 @@ def backward( if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) + elif isinstance(quantizer, Float8BlockQuantizer): + # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM 
requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bd2811cdb3..50c73d5faf 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -59,7 +59,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from transformer_engine.common.recipe import Recipe @@ -144,19 +144,25 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - if not isinstance(inputmat, QuantizedTensor): - columnwise_usage = backward_needs_input and isinstance( - input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer) + # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) + if force_high_precision_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group, quantizer=input_quantizer) + else: + if not isinstance(inputmat, QuantizedTensor): + columnwise_usage = backward_needs_input and isinstance( + input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer) + ) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, ) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - inputmat = input_quantizer(inputmat) - 
own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=input_quantizer, - ) else: if ( FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() @@ -517,6 +523,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) + elif isinstance(quantizer, Float8BlockQuantizer): + # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 138d1fd29e..2bf78814bd 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -167,6 +167,18 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: for i in range(len(shape) - 1): colwise_shape.append(shape[i]) return tuple(colwise_shape) + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + # if inp.ndim < 2: + # return False + # if inp.shape[-1] % self.block_len != 0: + # return False + # if math.prod(inp.shape[:-1]) % self.block_len != 0: + # return False + # return True + + return False # TODO: remove this, returning False for now to trigger BF16 allgather def make_empty( self, From c64d0e705cff31e5a7c7024e08330de51e67ee02 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 10:33:04 -0700 Subject: [PATCH 14/53] Edit and cleanup BF16 gather code. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 118 +++--------------- .../pytorch/module/layernorm_linear.py | 5 +- .../pytorch/module/layernorm_mlp.py | 5 +- transformer_engine/pytorch/module/linear.py | 9 +- .../pytorch/tensor/float8_blockwise_tensor.py | 17 +-- 5 files changed, 35 insertions(+), 119 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 0dc4777838..7cfb8c99d5 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -937,6 +937,7 @@ def _all_gather_fp8( return out, handle + def _all_gather_fp8_blockwise( inp: torch.Tensor, process_group: dist_group_type, @@ -957,45 +958,38 @@ def _all_gather_fp8_blockwise( dtype = inp.dtype elif isinstance(inp, Float8BlockwiseQTensorBase): if inp._rowwise_data is not None: - in_shape = inp._rowwise_data.device.size() + in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device - dtype = inp._rowwise_data.dtype elif inp._columnwise_data is not None: - in_shape = inp._columnwise_data.device.size() + in_shape = inp._columnwise_data.size() device = inp._columnwise_data.device - dtype = inp._columnwise_data.dtype else: raise ValueError("Got Float8BlockwiseQTensor input tensor without any data") - dtype = torch.bfloat16 + dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. 
else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " f"found {inp.__class__.__name__})" ) - world_size = get_distributed_world_size(process_group) # Check that quantizer is valid if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") - else: - assert quantizer.block_scaling_dim == 1 and quantizer.block_len == 128, "Only 1D blockwise quantization is supported for allgather" + elif not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") # Output tensor dims if out_shape is None: out_shape = list(inp.size()) out_shape[0] *= world_size - # Since gather is for inp tensor or grad tensor, the block scaling dimension should be 1D 1x128 - # This means that we need to do quantization in both directions, and then gather twice - # This also requires a bidirectional quantization without transpose kernel - # Scaling factors should be in compact format without sizzling fused for the purpose of communication # Doing BF16 gather for now as baseline because it's simpler - if ( - not isinstance(inp, Float8BlockwiseQTensorBase) - and quantizer is not None - and not quantizer.is_quantizable(inp) - ): + if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + # TODO: Because the quantization is after the gather, + # async_op is ignored. A better solution would be + # to quantize as a future callback and chain async + # handles. 
out = torch.empty( out_shape, dtype=dtype, @@ -1005,88 +999,12 @@ def _all_gather_fp8_blockwise( torch.distributed.all_gather_into_tensor(out, inp, group=process_group) out = quantizer(out) return out, None - - raise Exception("Not implemented for doing fp8 blockwise allgather") - - # TODO: finish the scheleton below for doing fp8 blockwise allgather - - # # Cast input tensor to Float8BlockwiseQTensor with required data - # if not isinstance(inp, Float8BlockwiseQTensorBase): - # inp = quantizer(inp) - # elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( - # quantizer.columnwise_usage and inp._columnwise_data is None - # ): - # warnings.warn( - # "Input and quantizer do not have matching usages. " - # "Dequantizing and requantizing to Float8BlockwiseQTensor." - # ) - # inp = quantizer(inp.dequantize()) - - # # Construct Float8BlockwiseQTensor output tensor - # out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - - # # Coalesce NCCL collectives - # with torch.distributed._coalescing_manager( - # group=process_group, - # device=device, - # async_ops=async_op, - # ) as coalescing_manager: - - # # Gather Float8BlockwiseQTensor data for row-wise usage - # if quantizer.rowwise_usage: - - # # Remove padding from Float8BlockwiseQTensor scale-inverses - # # TODO: figure out scale inv tensor shape here - - # # Launch all-gathers - # torch.distributed.all_gather_into_tensor( - # out_scale_inv, - # in_scale_inv, - # group=process_group, - # ) - # torch.distributed.all_gather_into_tensor( - # out._rowwise_data, - # inp._rowwise_data, - # group=process_group, - # ) - - # # Gather Float8BlockwiseQTensor data for column-wise usage - # if quantizer.columnwise_usage: - - # # Remove padding from Float8BlockwiseQTensor scale-inverses - # # TODO: figure out columnwise scale inv tensor shape here - - # # Launch all-gathers - # torch.distributed.all_gather_into_tensor( - # out_scale_inv, - # in_scale_inv, - # group=process_group, - # ) - # 
torch.distributed.all_gather_into_tensor( - # out._columnwise_data, - # inp._columnwise_data, - # group=process_group, - # ) - - # handle = coalescing_manager if async_op else None - - # # Unlink MXFP8, this fp8 blockwise tensor should also work with Hopper - # # This means that we need to transpose the gathered columnwise data - # # Example usage is grad_output tensor, ie. dY in linear backward - # # We want to gather two FP8 tensors (rowwise and columnwise) along dim0 - # # and then transpose the columnwise data to match the rowwise data - # # Make sure FP8 transpose is populated if needed - # needs_transpose = ( - # quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported() - # ) - # if needs_transpose: - # if handle is not None: - # handle.wait() - # handle = None - # # TODO: this transpose will transpose both data and scale inverses - # # out._create_transpose() - - # return out, handle + # Implementation of fp8 gather needs to account for: + # * Getting columnwise data as a transpose of how it is stored for GEMMS. + # * Gathering non GEMM swizzled scales. 
+ # * Refer to scaffold code when implementing at: + # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 + raise NotImplementedError("fp8 blockwise allgather not yet implemented") def _all_gather_mxfp8( @@ -1251,7 +1169,7 @@ def gather_along_first_dim( quantizer=quantizer, out_shape=out_shape, ) - + # FP8 block scaling case, block length = 128 if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): return _all_gather_fp8_blockwise( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index c626a263c7..a49bd1e109 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -241,7 +241,7 @@ def forward( ln_out_total = None ub_obj_fprop = None if with_input_all_gather: - # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + # TODO: Support FP8 gather of Float8BlockQuantizer. force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered @@ -645,7 +645,8 @@ def backward( # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + # TODO: Add support in the quantizer for + # rowwise=False, columnwise=True and configure that here. 
quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f856b2844f..f9b9c86163 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -262,7 +262,7 @@ def forward( ln_out_total = None ub_obj_lnout = None if sequence_parallel: - # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + # TODO: Support FP8 allgather of Float8Block quantization. force_high_precision_gather = isinstance(fc1_input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered @@ -710,7 +710,8 @@ def backward( # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + # TODO: Support rowwise=False, columnwise=True in quantizer + # and configure that here. quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 50c73d5faf..b6200f3086 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -144,11 +144,13 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - # TODO: force BF16 allgather in the case when fp8 gather is not fully supported + # TODO: Support FP8 allgather for FP8 block quantization. 
force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if force_high_precision_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim(inputmat, tp_group, quantizer=input_quantizer) + inputmat_total, _ = gather_along_first_dim( + inputmat, tp_group, quantizer=input_quantizer + ) else: if not isinstance(inputmat, QuantizedTensor): columnwise_usage = backward_needs_input and isinstance( @@ -524,7 +526,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: I had to set both True even though only columnwise is used, beucase there is no columnwise only kernel + # TODO: Support rowwise=False, columnwise=True in quantizer + # and configure that here. quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 2bf78814bd..adfa24127d 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -167,18 +167,11 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: for i in range(len(shape) - 1): colwise_shape.append(shape[i]) return tuple(colwise_shape) - - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" - # if inp.ndim < 2: - # return False - # if inp.shape[-1] % self.block_len != 0: - # return False - # if math.prod(inp.shape[:-1]) % self.block_len != 0: - # return False - # return True - - return False # TODO: remove this, returning False for now to trigger BF16 allgather + + # TODO(kwyss): With FP8 gather support, we need to implement a + # 
shape/layout/swizzle check to know whether FP8 gather works + # cleanly by stacking data without aliasing tiles and whether + # the scales also stack on the proper dimensions. def make_empty( self, From 99b5908e5135c9b22efdf593d5de2b52f547848a Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 18:16:02 -0700 Subject: [PATCH 15/53] Update test import. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_scaling_exact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 44af51e78a..d96d482ce9 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -21,7 +21,7 @@ BlockwiseQuantizerReference, QuantizeResult, ) -from tests.pytorch.test_float8_current_scaling_exact import ( +from test_float8_current_scaling_exact import ( TestFP8RecipeLinearBase, TestFP8RecipeLayerNormLinearBase, ) From 6daa8df98c4503577d67e4c9ff7a57d722bcf102 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Thu, 3 Apr 2025 21:31:38 -0700 Subject: [PATCH 16/53] support columnwise only mode to 1D quantize kernel Signed-off-by: zhongboz --- tests/pytorch/test_float8blockwisetensor.py | 16 +++++++ transformer_engine/common/common.h | 15 +++++++ .../common/transpose/cast_transpose.h | 5 ++- .../quantize_transpose_vector_blockwise.cu | 44 ++++++++++++------- .../common/util/cast_kernels.cuh | 8 +++- .../pytorch/module/layernorm_linear.py | 4 -- .../pytorch/module/layernorm_mlp.py | 4 -- transformer_engine/pytorch/module/linear.py | 4 -- .../pytorch/tensor/float8_blockwise_tensor.py | 1 - 9 files changed, 68 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index d030426b74..e403ae14a0 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -150,6 
+150,22 @@ def test_quantize_dequantize_dtypes( ) self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1]) + def test_quantize_dequantize_columnwise_only( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=False, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True) + @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] ) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..aa71d70c8a 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -44,6 +44,21 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +// enum class for rowwise usage: None, ROWWISE +// there is no need to transpose the rowwise usage tensor, so no ROWWISE_TRANSPOSE +enum class RowwiseUsageOption { + NONE, // No rowwise data + ROWWISE // Rowwise data +}; + +// enum class for columnwise usage: None, COLUMNWISE, COLUMNWISE_TRANSPOSE +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class ColumnwiseUsageOption { + NONE, // No columnwise data + COLUMNWISE, // TODO(zhongbo): Columnwise data (current not supported getting columnwise data without transposing) + COLUMNWISE_TRANSPOSE // Columnwise data with transpose +}; + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { 
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 298d087337..bfe2e02757 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -32,8 +32,9 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + RowwiseUsageOption rowwise_option, + ColumnwiseUsageOption columnwise_option, + const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 732d97999c..661478f594 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -146,7 +146,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - bool return_transpose, bool pow_2_scaling) { + RowwiseUsageOption rowwise_option, + ColumnwiseUsageOption columnwise_option, + const bool pow_2_scaling) { + + bool return_rowwise = rowwise_option == RowwiseUsageOption::ROWWISE; + bool return_columnwise_transpose = columnwise_option == ColumnwiseUsageOption::COLUMNWISE_TRANSPOSE; + using SMemVec = Vec; using OVec = Vec; union IVec { @@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast 
and store to output_c - { + if (return_rowwise){ constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if (return_transpose) { + if (return_columnwise_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -389,10 +395,15 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, + RowwiseUsageOption rowwise_option, + ColumnwiseUsageOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + // assert that rowwise_option and columnwise_option are not both NONE + NVTE_CHECK(rowwise_option != RowwiseUsageOption::NONE || columnwise_option != ColumnwiseUsageOption::NONE, + "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -408,21 +419,20 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. 
- - NVTE_CHECK(input.shape.size() == output.shape.size(), - "Input and output must have the same shape."); - size_t scale_stride_x = 0; size_t scale_stride_y = 0; - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; - size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (return_transpose) { + if (rowwise_option != RowwiseUsageOption::NONE) { + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + } + + if (columnwise_option != ColumnwiseUsageOption::NONE) { NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -469,8 +479,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, - pow2_scale);) // kAligned + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned ) // OutputType ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 412a6f6ef0..f10fefca36 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1279,11 +1279,17 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); constexpr bool force_pow_2_scales = true; + 
RowwiseUsageOption rowwise_option = output_tensor->has_data() + ? RowwiseUsageOption::ROWWISE + : RowwiseUsageOption::NONE; + ColumnwiseUsageOption columnwise_option = output_tensor->has_columnwise_data() + ? ColumnwiseUsageOption::COLUMNWISE_TRANSPOSE + : ColumnwiseUsageOption::NONE; quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, /*epsilon=*/0.0, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + rowwise_option, columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a49bd1e109..7e057b3139 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -644,10 +644,6 @@ def backward( if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) - elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: Add support in the quantizer for - # rowwise=False, columnwise=True and configure that here. 
- quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f9b9c86163..b0b1fcbf02 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -709,10 +709,6 @@ def backward( if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) - elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: Support rowwise=False, columnwise=True in quantizer - # and configure that here. - quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b6200f3086..d35bb4ece3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -525,10 +525,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) - elif isinstance(quantizer, Float8BlockQuantizer): - # TODO: Support rowwise=False, columnwise=True in quantizer - # and configure that here. 
- quantizer.set_usage(rowwise=True, columnwise=True) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index adfa24127d..7dd79b458b 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -44,7 +44,6 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - assert rowwise self.dtype = fp8_dtype self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales From fb661482525880183be92c46ad5416ec751caa00 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 10:15:07 -0700 Subject: [PATCH 17/53] Format and move enum Signed-off-by: Keith Wyss --- tests/pytorch/test_float8blockwisetensor.py | 4 +- transformer_engine/common/common.h | 15 ------ .../common/transpose/cast_transpose.h | 27 +++++++++- .../quantize_transpose_vector_blockwise.cu | 52 ++++++++++--------- .../common/util/cast_kernels.cuh | 15 +++--- 5 files changed, 63 insertions(+), 50 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index e403ae14a0..2558a1d190 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -164,7 +164,9 @@ def test_quantize_dequantize_columnwise_only( columnwise=True, block_scaling_dim=block_scaling_dim, ) - self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True + ) @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] diff --git a/transformer_engine/common/common.h 
b/transformer_engine/common/common.h index aa71d70c8a..b1fe436379 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -44,21 +44,6 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } -// enum class for rowwise usage: None, ROWWISE -// there is no need to transpose the rowwise usage tensor, so no ROWWISE_TRANSPOSE -enum class RowwiseUsageOption { - NONE, // No rowwise data - ROWWISE // Rowwise data -}; - -// enum class for columnwise usage: None, COLUMNWISE, COLUMNWISE_TRANSPOSE -// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling -enum class ColumnwiseUsageOption { - NONE, // No columnwise data - COLUMNWISE, // TODO(zhongbo): Columnwise data (current not supported getting columnwise data without transposing) - COLUMNWISE_TRANSPOSE // Columnwise data with transpose -}; - inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index bfe2e02757..3148b4f720 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -29,11 +29,34 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream); +// enum class for rowwise usage +enum class FP8BlockwiseRowwiseOption { + // No rowwise data + NONE, + // Rowwise data, scales in GEMM format + ROWWISE + // TODO: FP8 all gather requires some changes. + // 1. Compact scales are better for gathering than the GEMM format. 
+}; + +// enum class for columnwise usage +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class FP8BlockwiseColumnwiseOption { + // No columnwise data + NONE, + // Columnwise data transposed from original shape. + // Scales in GEMM format corresponding to GEMM ingesting transposed column data. + COLUMNWISE_TRANSPOSE + // TODO: FP8 all gather requires some changes. + // 1. The transpose gets in the way of the all gather. + // 2. Compact scales are better for gathering than the GEMM format. +}; + void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - RowwiseUsageOption rowwise_option, - ColumnwiseUsageOption columnwise_option, + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 661478f594..91f73dea1e 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -16,11 +16,15 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" #include "common/utils.cuh" namespace transformer_engine { namespace { +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + // clang-format off /* @@ -138,20 +142,16 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); template -__global__ void 
__launch_bounds__(kThreadsPerBlock) - block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, - OType* const output_t, CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, const size_t row_length, - const size_t num_rows, const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon, - RowwiseUsageOption rowwise_option, - ColumnwiseUsageOption columnwise_option, - const bool pow_2_scaling) { - - bool return_rowwise = rowwise_option == RowwiseUsageOption::ROWWISE; - bool return_columnwise_transpose = columnwise_option == ColumnwiseUsageOption::COLUMNWISE_TRANSPOSE; +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scaling) { + bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; + bool return_columnwise_transpose = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; using SMemVec = Vec; using OVec = Vec; @@ -209,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast and store to output_c - if (return_rowwise){ + if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -395,14 +395,14 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& 
output, SimpleTensor& output_t, const float epsilon, - RowwiseUsageOption rowwise_option, - ColumnwiseUsageOption columnwise_option, - const bool pow2_scale, - cudaStream_t stream) { + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); // assert that rowwise_option and columnwise_option are not both NONE - NVTE_CHECK(rowwise_option != RowwiseUsageOption::NONE || columnwise_option != ColumnwiseUsageOption::NONE, + NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || + columnwise_option != FP8BlockwiseColumnwiseOption::NONE, "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; @@ -424,7 +424,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (rowwise_option != RowwiseUsageOption::NONE) { + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + "Unexpected rowwise enum value"); NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); size_t scale_k = scale_inv.shape[1]; @@ -432,7 +434,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor scale_stride_y = 1; } - if (columnwise_option != ColumnwiseUsageOption::NONE) { + if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, + "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -481,8 +485,8 @@ void quantize_transpose_vector_blockwise(const 
SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index f10fefca36..39c7769d69 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1279,17 +1279,16 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); constexpr bool force_pow_2_scales = true; - RowwiseUsageOption rowwise_option = output_tensor->has_data() - ? RowwiseUsageOption::ROWWISE - : RowwiseUsageOption::NONE; - ColumnwiseUsageOption columnwise_option = output_tensor->has_columnwise_data() - ? ColumnwiseUsageOption::COLUMNWISE_TRANSPOSE - : ColumnwiseUsageOption::NONE; + FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() + ? FP8BlockwiseRowwiseOption::ROWWISE + : FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + output_tensor->has_columnwise_data() ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE + : FP8BlockwiseColumnwiseOption::NONE; quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, - rowwise_option, columnwise_option, force_pow_2_scales, stream); + /*epsilon=*/0.0, rowwise_option, columnwise_option, force_pow_2_scales, stream); break; } default: From 70c503462026689f96d47d632b5a04b31d365c01 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 10:37:26 -0700 Subject: [PATCH 18/53] Skip alloc. Signed-off-by: Keith Wyss --- .../_internal/float8_blockwise_tensor_base.py | 4 ++-- .../pytorch/tensor/float8_blockwise_tensor.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index c33535f13c..7dc380606d 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase: def __new__( cls, *args, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 7dd79b458b..df7c9b2690 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -185,13 +185,16 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_shape = 
self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None From d81946ca569525988c998e0e87a888855e07d7e6 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Fri, 4 Apr 2025 02:03:50 -0700 Subject: [PATCH 19/53] try async bf16 gather Signed-off-by: zhongboz --- transformer_engine/pytorch/distributed.py | 7 ++++--- transformer_engine/pytorch/module/layernorm_linear.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 2 ++ transformer_engine/pytorch/module/linear.py | 2 ++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 7cfb8c99d5..98024d87a4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -996,9 +996,10 @@ def _all_gather_fp8_blockwise( device=device, memory_format=torch.contiguous_format, ) - torch.distributed.all_gather_into_tensor(out, inp, group=process_group) - out = quantizer(out) - return out, None + handle = torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=async_op) + if handle is None: + out = quantizer(out) + return out, handle # Implementation of fp8 gather needs to account for: # * Getting columnwise data as a transpose of how it is stored for GEMMS. # * Gathering non GEMM swizzled scales. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7e057b3139..00ef761fcf 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -733,6 +733,8 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor): + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b0b1fcbf02..dff3eebe64 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -935,6 +935,8 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.fc1_input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor): + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d35bb4ece3..366725f952 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -619,6 +619,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None + if ctx.input_quantizer is not None and not isinstance(inputmat_total, QuantizedTensor): + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): From a577801aee3d7d95e99804ebeb18522a24580c6e Mon Sep 17 00:00:00 2001 From: Keith Wyss 
Date: Fri, 4 Apr 2025 15:14:31 -0700 Subject: [PATCH 20/53] Format python code. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 98024d87a4..6372a9961d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -996,7 +996,9 @@ def _all_gather_fp8_blockwise( device=device, memory_format=torch.contiguous_format, ) - handle = torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=async_op) + handle = torch.distributed.all_gather_into_tensor( + out, inp, group=process_group, async_op=async_op + ) if handle is None: out = quantizer(out) return out, handle diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 00ef761fcf..431d1687b1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -733,7 +733,9 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor): + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index dff3eebe64..0d81a76465 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -935,7 +935,9 @@ def backward( if ln_out_total_work is not None: 
ln_out_total_work.wait() ln_out_total_work = None - if ctx.fc1_input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor): + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 366725f952..8bfb63dad8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -619,7 +619,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.input_quantizer is not None and not isinstance(inputmat_total, QuantizedTensor): + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data From a6e9d28130fb530361ecefeedfdb57d6ef784b1e Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 15:33:30 -0700 Subject: [PATCH 21/53] Document and type code. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 29 ++++++++++++++----- .../pytorch/module/layernorm_linear.py | 2 ++ .../pytorch/module/layernorm_mlp.py | 2 ++ transformer_engine/pytorch/module/linear.py | 2 ++ 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 6372a9961d..1ec28dd68e 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -945,8 +945,20 @@ def _all_gather_fp8_blockwise( async_op: bool = False, quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, -) -> tuple[Float8BlockwiseQTensorBase, Optional[torch.distributed.Work]]: - """All-gather FP8 tensor along first dimension for blockwise quantization.""" +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """ + All-gather FP8 tensor along first dimension for blockwise quantization. + + Usually returns a Float8BlockwiseQTensorBase. In the case that the + all gather is done asynchronously, but quantization is deferred until + after the gather, this returns the full precision tensor. + + NOTE: It would be preferable to always honor the quantizer and quantize + the result, and this may be possible via calling `get_future()` on the + asynchronous handler, calling `make_empty()` on the quantizer, and chaining + a callback with `then()` to perform `update_quantized`. This invites + complications and also requires pre-allocating the quantized tensor. + """ # Input tensor attributes in_shape: Iterable[int] @@ -986,10 +998,6 @@ def _all_gather_fp8_blockwise( # Doing BF16 gather for now as baseline because it's simpler if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: - # TODO: Because the quantization is after the gather, - # async_op is ignored. A better solution would be - # to quantize as a future callback and chain async - # handles. 
out = torch.empty( out_shape, dtype=dtype, @@ -999,6 +1007,7 @@ def _all_gather_fp8_blockwise( handle = torch.distributed.all_gather_into_tensor( out, inp, group=process_group, async_op=async_op ) + # NOTE: if async, out will not be quantized. if handle is None: out = quantizer(out) return out, handle @@ -1148,7 +1157,13 @@ def gather_along_first_dim( async_op: bool = False, quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: - """All-gather tensors and concatenate along first dimension.""" + """ + All-gather tensors and concatenate along first dimension. + + NOTE: Caller should be aware that there are asynchronous cases + where quantizer is not None, but the output will not be quantized. + This affects Float8BlockQuantizer. + """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 431d1687b1..179647954a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -736,6 +736,8 @@ def backward( if ctx.input_quantizer is not None and not isinstance( ln_out_total, QuantizedTensor ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0d81a76465..025b991339 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -938,6 +938,8 @@ def backward( if ctx.fc1_input_quantizer is not None and not isinstance( ln_out_total, QuantizedTensor ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8bfb63dad8..afb327ae9e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -622,6 +622,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.input_quantizer is not None and not isinstance( inputmat_total, QuantizedTensor ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data From 52c18a13829cad184f71da536a024d1a6d85d302 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 15:45:14 -0700 Subject: [PATCH 22/53] Update pytorch lint errors. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 10 +++------- transformer_engine/pytorch/module/layernorm_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_mlp.py | 6 ++---- transformer_engine/pytorch/module/linear.py | 3 +-- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 1ec28dd68e..e8671bc278 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -24,7 +24,7 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer -from .tensor.float8_blockwise_tensor import Float8BlockQuantizer, Float8BlockwiseQTensor +from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase @@ -961,22 
+961,18 @@ def _all_gather_fp8_blockwise( """ # Input tensor attributes - in_shape: Iterable[int] device: torch.device dtype: torch.dtype if isinstance(inp, torch.Tensor): - in_shape = inp.size() device = inp.device dtype = inp.dtype elif isinstance(inp, Float8BlockwiseQTensorBase): if inp._rowwise_data is not None: - in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device elif inp._columnwise_data is not None: - in_shape = inp._columnwise_data.size() device = inp._columnwise_data.device else: - raise ValueError("Got Float8BlockwiseQTensor input tensor without any data") + raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. else: raise ValueError( @@ -988,7 +984,7 @@ def _all_gather_fp8_blockwise( # Check that quantizer is valid if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") - elif not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") # Output tensor dims diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 179647954a..06bde8ab12 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -61,7 +61,6 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param -from transformer_engine.common.recipe import Recipe from ..cpp_extensions import ( general_gemm, ) @@ -241,7 +240,7 @@ def forward( ln_out_total = None ub_obj_fprop = None if with_input_all_gather: - # TODO: Support FP8 gather of Float8BlockQuantizer. 
+ # TODO(kwyss): Support FP8 allgather for FP8 block quantization. force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 025b991339..f3331e5105 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -52,7 +52,6 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) -from transformer_engine.common.recipe import Recipe from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -117,8 +116,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "srelu": (tex.srelu, tex.dsrelu, None), } - else: - raise NotImplementedError(f"Unhandled recipe type {recipe}") + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -262,7 +260,7 @@ def forward( ln_out_total = None ub_obj_lnout = None if sequence_parallel: - # TODO: Support FP8 allgather of Float8Block quantization. + # TODO(kwyss): Support FP8 allgather of Float8Block quantization. 
force_high_precision_gather = isinstance(fc1_input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index afb327ae9e..2cca226708 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -61,7 +61,6 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param -from transformer_engine.common.recipe import Recipe __all__ = ["Linear"] @@ -144,7 +143,7 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - # TODO: Support FP8 allgather for FP8 block quantization. + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if force_high_precision_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) From 80057a67a7418e203dec2bbb16d11a1aa4bc93da Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 18:21:55 -0700 Subject: [PATCH 23/53] Dont set high precision dtype. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index df7c9b2690..695c5ffb8c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -496,7 +496,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst.dtype = src.dtype # Check that tensor dimensions match if ( From 77cfef4e128d353ef778e5ee941f9cecf422ef9a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 5 Apr 2025 04:09:39 +0000 Subject: [PATCH 24/53] Add test for sanity and CG; fix CG for sequential? Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_cuda_graphs.py | 1 + tests/pytorch/test_sanity.py | 1 + transformer_engine/pytorch/fp8.py | 133 ++++++++++++++------------- transformer_engine/pytorch/ops/op.py | 2 +- 4 files changed, 73 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5a1dc3f732..5f896aaafe 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -55,6 +55,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] # Supported data types diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 69ac8f7996..03e8563975 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -106,6 +106,7 @@ def is_fp8_supported(self): None, # Test non-FP8 recipe.MXFP8BlockScaling(), # Test default recipe.Float8CurrentScaling(), # Test default + recipe.Float8BlockScaling(), # Test default 
recipe.DelayedScaling(), # Test default recipe.DelayedScaling( # Test most_recent algo amax_history_len=16, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 681bddd774..030d62cfaa 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -import itertools import os from contextlib import contextmanager from collections import deque @@ -827,7 +826,7 @@ def create( ) @abc.abstractmethod - def make_quantizers(self) -> list: + def make_quantizers(self, fp8_output: bool = True) -> list: """Convert recipe state to quantizers. Quantizers are builder classes for quantized tensors. They are @@ -877,7 +876,7 @@ def __init__( device=device, ) - def make_quantizers(self) -> list: + def make_quantizers(self, fp8_output: bool = True) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_tensor import Float8Quantizer @@ -917,7 +916,7 @@ def __init__( device = torch.device("cuda") self.device = device - def make_quantizers(self) -> list: + def make_quantizers(self, fp8_output: bool = True) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ @@ -954,7 +953,7 @@ def __init__( if device is None: device = torch.device("cuda") - def make_quantizers(self) -> list: + def make_quantizers(self, fp8_output: bool = True) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.mxfp8_tensor import MXFP8Quantizer @@ -994,7 +993,7 @@ def __init__( device = torch.device("cuda") self.device = device - def make_quantizers(self) -> list: + def make_quantizers(self, fp8_output: bool = True) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. 
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer @@ -1002,65 +1001,73 @@ def make_quantizers(self) -> list: # The index convention (coming from base.py set_meta_tensor) # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) - ) - assert self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), + # Handle ops that do not support FP8 gemm output. 
+ quantizers_per_gemm = 3 if fp8_output else 2 + assert self.num_quantizers % quantizers_per_gemm == 0 + num_gemms = self.num_quantizers // quantizers_per_gemm + + quantizers = [] + for _ in range(num_gemms): + quantizers.append( + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ) + ) + quantizers.append( + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ) + ) + if fp8_output: + quantizers.append( Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, + fp8_dtype=self.qx_dtype, rowwise=True, columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ) + ) + return quantizers + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + quantizers_per_gemm = 2 if fp8_output else 1 + assert self.num_quantizers % quantizers_per_gemm == 0 + num_gemms = self.num_quantizers // quantizers_per_gemm + + quantizers = [] + for _ in range(num_gemms): + quantizers.append( + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ) ) - ) + if fp8_output: + 
quantizers.append( + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ) + ) + return quantizers diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..7b28fd5b66 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -234,7 +234,7 @@ def _reset_quantization_recipe_state( } # Construct builder class for quantized tensors - self._quantizers[mode] = recipe_state.make_quantizers() + self._quantizers[mode] = recipe_state.make_quantizers(fp8_output=False) def _update_quantization_recipe_state( self, From dbcff160a47a2eb83b6ef2aa45e76123664979d2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 15:04:33 -0700 Subject: [PATCH 25/53] Keep make_quantizers API stable Update num_quantizers instead to pass cuda_graph tests. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/fp8.py | 133 +++++++++--------- .../pytorch/ops/basic/basic_linear.py | 10 +- transformer_engine/pytorch/ops/op.py | 5 +- 3 files changed, 74 insertions(+), 74 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 030d62cfaa..7517c6e4fd 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -826,7 +827,7 @@ def create( ) @abc.abstractmethod - def make_quantizers(self, fp8_output: bool = True) -> list: + def make_quantizers(self) -> list: """Convert recipe state to quantizers. Quantizers are builder classes for quantized tensors. 
They are @@ -876,7 +877,7 @@ def __init__( device=device, ) - def make_quantizers(self, fp8_output: bool = True) -> list: + def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_tensor import Float8Quantizer @@ -916,7 +917,7 @@ def __init__( device = torch.device("cuda") self.device = device - def make_quantizers(self, fp8_output: bool = True) -> list: + def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ @@ -953,7 +954,7 @@ def __init__( if device is None: device = torch.device("cuda") - def make_quantizers(self, fp8_output: bool = True) -> list: + def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.mxfp8_tensor import MXFP8Quantizer @@ -993,7 +994,7 @@ def __init__( device = torch.device("cuda") self.device = device - def make_quantizers(self, fp8_output: bool = True) -> list: + def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_blockwise_tensor import Float8BlockQuantizer @@ -1001,73 +1002,65 @@ def make_quantizers(self, fp8_output: bool = True) -> list: # The index convention (coming from base.py set_meta_tensor) # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. - - # Handle ops that do not support FP8 gemm output. 
- quantizers_per_gemm = 3 if fp8_output else 2 - assert self.num_quantizers % quantizers_per_gemm == 0 - num_gemms = self.num_quantizers // quantizers_per_gemm - - quantizers = [] - for _ in range(num_gemms): - quantizers.append( - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ) + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] ) - quantizers.append( - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ) - ) - if fp8_output: - quantizers.append( + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + 
itertools.chain.from_iterable( + [ + [ Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, + fp8_dtype=self.qgrad_dtype, rowwise=True, columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ) - ) - return quantizers - - assert self.mode == "backward", f"Unexpected mode {self.mode}" - quantizers_per_gemm = 2 if fp8_output else 1 - assert self.num_quantizers % quantizers_per_gemm == 0 - num_gemms = self.num_quantizers // quantizers_per_gemm - - quantizers = [] - for _ in range(num_gemms): - quantizers.append( - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ) + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] ) - if fp8_output: - quantizers.append( - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ) - ) - return quantizers + ) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb93eb5e6b..cebc2fc9c5 100644 --- 
a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -254,10 +254,16 @@ def _canonicalize_tensor_parallelism( ) def num_quantizers(self, mode: str) -> int: + # Adhere to consistent conventions of the non-fused + # module code, where fwd quantized tensors are (x, w, output) + # and bwd quantized tensors are (ygrad, dgrad) + # + # Fused code need not use the output or dgrad quantizers. + # Since fp8_output flag is not available. if mode == "forward": - return 2 + return 3 if mode == "backward": - return 1 + return 2 return 0 def reset_parameters(self) -> None: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 7b28fd5b66..88fffe80e0 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -234,7 +234,7 @@ def _reset_quantization_recipe_state( } # Construct builder class for quantized tensors - self._quantizers[mode] = recipe_state.make_quantizers(fp8_output=False) + self._quantizers[mode] = recipe_state.make_quantizers() def _update_quantization_recipe_state( self, @@ -260,7 +260,8 @@ def _update_quantization_recipe_state( recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState) + ) or (recipe.fp8blockwise() and not isinstance(recipe_state, Float8BlockScaling)) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return From 9e50b6d644738ffd46695fd2aa37aede88ab41d8 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 15:52:24 -0700 Subject: [PATCH 26/53] Fix import name. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/ops/basic/basic_linear.py | 7 +++++++ transformer_engine/pytorch/ops/op.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cebc2fc9c5..b0f67b4cb3 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,6 +23,7 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext @@ -489,6 +490,12 @@ def _functional_forward( "Attempting to generate MXFP8 output tensor, " "but GEMM with MXFP8 output is not supported" ) + if isinstance(output_quantizer, Float8BlockQuantizer): + raise RuntimeError( + "Attempting to generate Float8BlockQuantized output tensor, " + "but GEMM with Float8BlockQuantized output is not supported" + ) + if output_quantizer is not None: output_quantizer.set_usage(rowwise=True, columnwise=False) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 88fffe80e0..5cb698350f 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -17,6 +17,7 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -259,9 +260,13 @@ def _update_quantization_recipe_state( continue recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( - recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not 
isinstance(recipe_state, MXFP8BlockScalingRecipeState) - ) or (recipe.fp8blockwise() and not isinstance(recipe_state, Float8BlockScaling)) + (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) + or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + or ( + recipe.fp8blockwise() + and not isinstance(recipe_state, Float8BlockScalingRecipeState) + ) + ) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return From 0e2359184204f9cff809d279894883d19b857986 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 16:07:34 -0700 Subject: [PATCH 27/53] Rename recipe method. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_current_scaling_exact.py | 4 ++-- tests/pytorch/test_numerics.py | 10 +++++----- transformer_engine/common/recipe/__init__.py | 2 +- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/module/base.py | 4 +++- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 8 ++++---- transformer_engine/pytorch/ops/op.py | 2 +- 8 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index d8c19143d7..b8e5b936d1 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -100,7 +100,7 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias # Expected tensor names based on the naming template if recipe.float8_current_scaling(): scaling_type = "ScalingType.PER_TENSOR" - elif recipe.fp8blockwise(): + elif recipe.float8_block_scaling(): scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" else: scaling_type = "Unknown" @@ -443,7 +443,7 @@ def _check_golden_tensor_dumps( # Expected tensor names based on the naming template if recipe.float8_current_scaling(): scaling_type = 
"ScalingType.PER_TENSOR" - elif recipe.fp8blockwise(): + elif recipe.float8_block_scaling(): scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" else: scaling_type = "Unknown" diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 54df59fe65..7a930b6cde 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -567,7 +567,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if recipe.fp8blockwise() and not fp8_block_scaling_available: + if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -681,7 +681,7 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if recipe.fp8blockwise() and not fp8_block_scaling_available: + if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1536,7 +1536,7 @@ def test_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") - if recipe.fp8blockwise(): + if recipe.float8_block_scaling(): pytest.skip("Grouped linear for FP8 blockwise unsupported.") config = model_configs[model] @@ -1733,7 +1733,7 @@ def test_padding_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") - if recipe.fp8blockwise(): + if recipe.float8_block_scaling(): pytest.skip("Float8 block scaling unsupported for grouped linear.") config = model_configs[model] @@ -1945,7 +1945,7 @@ def test_gpt_fp8_parameters(dtype, 
bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if recipe.fp8blockwise() and not fp8_block_scaling_available: + if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 89bfc69137..9bf054c55a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -81,7 +81,7 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) - def fp8blockwise(self): + def float8_block_scaling(self): """Whether the given recipe is float8 blockwise scaling.""" return isinstance(self, Float8BlockScaling) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 7517c6e4fd..c02ff73391 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -815,7 +815,7 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState - elif recipe.fp8blockwise(): + elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f136df864b..47b19a69f6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -502,7 +502,9 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return - if recipe.fp8blockwise() and isinstance(recipe_state, Float8BlockScalingRecipeState): + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): 
return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ad869bf0d4..1ea66a7f2c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -91,7 +91,7 @@ def forward( # TODO Support Float8 Current Scaling # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") - if fp8 and FP8GlobalStateManager.get_fp8_recipe().fp8blockwise(): + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") # Make sure input dimensions are compatible diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f3331e5105..50ec8ce490 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -105,7 +105,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] - if recipe.float8_current_scaling() or recipe.fp8blockwise(): + if recipe.float8_current_scaling() or recipe.float8_block_scaling(): return { "gelu": (tex.gelu, tex.dgelu, None), "relu": (tex.relu, tex.drelu, None), @@ -384,7 +384,7 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs else: fc1_out, *_ = fc1_outputs - if fp8 and FP8GlobalStateManager.get_fp8_recipe().fp8blockwise(): + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): # tex.quantize does not support GELU fusion for blockwise. 
act_out = activation_func(fc1_out, None) act_out = tex.quantize(act_out, fc2_input_quantizer) @@ -769,7 +769,7 @@ def backward( grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.fp8blockwise(): + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( act_out, @@ -789,7 +789,7 @@ def backward( out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: - if ctx.fp8 and ctx.fp8_recipe.fp8blockwise() and fc2_bias is not None: + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None: # BGRAD not fused with GEMM for float8 blockwise gemm. fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) fc2_bias_grad = fc2_bias_grad_ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 5cb698350f..4c90ce28e3 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -263,7 +263,7 @@ def _update_quantization_recipe_state( (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) or ( - recipe.fp8blockwise() + recipe.float8_block_scaling() and not isinstance(recipe_state, Float8BlockScalingRecipeState) ) ) From db2aaa9e7765a3bcd1e652e3539d57e85cdbd5bc Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Mon, 7 Apr 2025 16:25:06 -0700 Subject: [PATCH 28/53] Subchannel Block quantized GEMM (#1545) * Add GEMM logic for blockwise quantized tensors. GEMM test cases included in pytorch integration. Signed-off-by: Keith Wyss * Update NVTE_BLOCK_SCALING for GEMM. Signed-off-by: Keith Wyss * Gate feature on CUDA 12.9 Signed-off-by: Keith Wyss * Gemm typo. Signed-off-by: Keith Wyss * Remove unecessary type converter change. Signed-off-by: Keith Wyss * Reflect epilogue availability and test supported epilogues. 
Signed-off-by: Keith Wyss * GEMM simplifications from recipe branch. Signed-off-by: Keith Wyss * Format py code. Signed-off-by: Keith Wyss * Update GEMM DGelu tests to match support depending on output dtype. Signed-off-by: Keith Wyss * Force pow2Scales in GEMM Signed-off-by: Keith Wyss * Add GEMM test to pytorch test suite. Signed-off-by: Keith Wyss * Add copyright to GEMM test. Signed-off-by: Keith Wyss * Update import for GEMM test. Signed-off-by: Keith Wyss * Add license. Signed-off-by: Keith Wyss * Update test gemm supported predicate. Signed-off-by: Keith Wyss * Use sgemm like interfaces and naming. Signed-off-by: Keith Wyss * Rewrite GEMM comment. Signed-off-by: Keith Wyss * MR Feedback. Signed-off-by: Keith Wyss * Refactor GEMM param canonicalization Configure A and B matrices separately. Have separate code path for each scaling mode. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Prune number of tests. 
Signed-off-by: Keith Wyss --------- Signed-off-by: Keith Wyss Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 1 + .../blockwise_fp8_gemm_reference.py | 242 +++++ .../blockwise_quantizer_reference.py | 1 + .../test_float8_blockwise_gemm_exact.py | 975 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 364 ++++--- .../common/normalization/layernorm/ln_api.cpp | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- 7 files changed, 1445 insertions(+), 146 deletions(-) create mode 100644 tests/pytorch/references/blockwise_fp8_gemm_reference.py create mode 100644 tests/pytorch/test_float8_blockwise_gemm_exact.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eaededc4..1206012195 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail " python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py 
b/tests/pytorch/references/blockwise_fp8_gemm_reference.py new file mode 100644 index 0000000000..5aef986e37 --- /dev/null +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128): + pid = tl.program_id(0) + idx = pid * BLOCK + tl.arange(0, BLOCK) + mask = idx < M * N + + row = idx // N + col = idx % N + + y_offset = row * y_str0 + col * y_str1 + x_offset = row * N + col + s_offset = row * N + col + + y = tl.load(y_ptr + y_offset, mask=mask) + x = tl.load(x_ptr + x_offset, mask=mask) + s = tl.load(s_ptr + s_offset, mask=mask) + + tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask) + + +def fused_fma(y, x, s, BLOCK=128): + """ + Fused multiply-add operation (y = y + x * s). + + PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation). + This function also supports cases where 'y' is non-contiguous in memory. + """ + + assert ( + y.shape == x.shape == s.shape and y.dim() == 2 + ), "All tensors must be 2D with the same shape" + assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous" + + M, N = y.shape + grid = ((M * N + BLOCK - 1) // BLOCK,) + + fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK) + + return y + + +class CuBLASRefBlockwiseGemm: + """ + A cuBLAS compatible reference implementation of subchannel GEMM. 
+ """ + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + demunged_sx: torch.Tensor, + demunged_sw: torch.Tensor, + quant_tile_shape_x: Tuple[int, int], + quant_tile_shape_w: Tuple[int, int], + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + use_split_accumulator: bool = False, + ) -> torch.Tensor: + # demunge scale shapes for cuBLAS + is_a_1d_scaled = quant_tile_shape_x[0] == 1 + is_b_1d_scaled = quant_tile_shape_w[0] == 1 + M, K = qx.shape + N, K = qw.shape + + # mm_tile_shape = (tile_m, tile_n, tile_k) + mm_tile_shape = ( + quant_tile_shape_x[0], + quant_tile_shape_w[0], + quant_tile_shape_w[1], + ) + if bias is not None and bias.numel(): + # To match cuBLAS more closely when bias is applied, + # the reference accumulates into float32, and cast to + # bfloat16 is deferred until after the GEMM. + out_dtype_for_ref = torch.float32 + else: + out_dtype_for_ref = out_dtype + y = self.qgemm_blockwise_2d( + qx, + qw, + out_dtype_for_ref, + demunged_sx, + demunged_sw, + mm_tile_shape, + use_split_accumulator, + is_a_1d_scaled, + is_b_1d_scaled, + ) + if bias is not None and bias.numel(): + y += bias + y = y.to(dtype=out_dtype) + # cublas accumulation first convert to output dtype, then accumulate. + if accumulate: + assert out is not None + y = y + out + else: + assert out is None, "Output tensor should be None when accumulate is False." + + return y + + @classmethod + def qgemm_blockwise_2d( + cls, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + mm_tile_shape: Tuple[int, int, int], + use_split_accumulator: bool, + is_a_1d_scaled: bool, + is_b_1d_scaled: bool, + ) -> torch.Tensor: + """ + Difference between cuBLAS and CUTLASS GEMM implementations: + - cuBLAS accumulation equation: use different equation for each scaling mode. + - For accumulation C in epiloge, it first convert C to output dtype, then accumulate. 
+ """ + + M, K = qx.shape + N, K_w = qw.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + tile_len = 128 + # Calculate grid sizes without padding + grid_m = (M + tile_len - 1) // tile_len + grid_n = (N + tile_len - 1) // tile_len + grid_k = (K + tile_len - 1) // tile_len + + block_m, block_n, block_k = mm_tile_shape + scale_m_per_tile = tile_len // block_m + scale_n_per_tile = tile_len // block_n + assert block_k == tile_len, "block_k must be equal to tile_len" + + # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM: + # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers. + # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # Validate shapes of sx and sw + scale_m_per_tensor = (M + block_m - 1) // block_m + scale_n_per_tensor = (N + block_n - 1) // block_n + assert sx.shape == ( + scale_m_per_tensor, + grid_k, + ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}" + assert sw.shape == ( + scale_n_per_tensor, + grid_k, + ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}" + + for i in range(grid_m): + m_start = i * tile_len + m_end = min(m_start + tile_len, M) + m_size = m_end - m_start + + for j in range(grid_n): + n_start = j * tile_len + n_end = min(n_start + tile_len, N) + n_size = n_end - n_start + + y_block = y[m_start:m_end, n_start:n_end] + + for k in range(grid_k): + k_start = k * tile_len + k_end = min(k_start + tile_len, K) + k_size = k_end - k_start + + qx_block = ( + qx[m_start:m_end, k_start:k_end].clone().contiguous() + ) # Shape: [m_size, k_size] + qw_block = ( + qw[n_start:n_end, k_start:k_end].clone().contiguous() + ) # Shape: [n_size, k_size] + + # Extract scaling factors for the current 
blocks + sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze( + -1 + ) + sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0) + + # Perform qgemm with scaling factors fused in the GEMM + # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM + one = torch.tensor(1.0, dtype=torch.float32, device=qx.device) + y_partial = torch._scaled_mm( + qx_block, + qw_block.t(), + scale_a=one, + scale_b=one, + out_dtype=torch.float32, + use_fast_accum=not use_split_accumulator, + ) + + # Accumulate the partial result + if is_a_1d_scaled and is_b_1d_scaled: + # 1Dx1D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM + # y_block.add_(y_partial, alpha=scale.item()) + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + elif not is_a_1d_scaled and is_b_1d_scaled: + # 2Dx1D + # CuBLAS accumulation equation: y += (y * scale_b) * scale_a + y_partial = y_partial * sw_block + fused_fma( + y_block, + y_partial, + sx_block.expand_as(y_partial).contiguous(), + ) + elif is_a_1d_scaled and not is_b_1d_scaled: + # 1Dx2D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + else: + scale = sx_block * sw_block + fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous()) + + y = y.to(out_dtype) + return y diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index b98966f514..f5c9dc0e96 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: s_t = 
_pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + @classmethod def demunge_scale_shape_from_backend( cls, qtensor_shape: Tuple[int, int], diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py new file mode 100644 index 0000000000..9a1cfa2db8 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -0,0 +1,975 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from transformer_engine.pytorch.utils import get_device_compute_capability +from references.blockwise_quantizer_reference import CuBLASScaleMunger +from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm + + +def fp8_blockwise_gemm_supported() -> bool: + return ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ) + + +def cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + atol: float = 0.0, + rtol: float = 0.0 +): + if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2: + pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2") + if not (is_x_1d_scaled or is_w_1d_scaled): + pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile") + if not 
fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + if noise_type == "uniform": + x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude + w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude + elif noise_type == "normal": + x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude + w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude + else: + assert False + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude + else: + out = None + + assert not (use_bias and use_grad), "Bias grad not supported by GEMM" + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = 
w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Reference GEMM + ref_gemm = CuBLASRefBlockwiseGemm() + scale_decoder = CuBLASScaleMunger() + qx_data = ( + qx._columnwise_data.view(dtype=x_dtype) + if x_columnwise + else qx._rowwise_data.view(dtype=x_dtype) + ) + qw_data = ( + qw._columnwise_data.view(dtype=w_dtype) + if w_columnwise + else qw._rowwise_data.view(dtype=w_dtype) + ) + ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv + ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv + y_ref = ref_gemm.qgemm( + qx=qx_data, + qw=qw_data, + out_dtype=out_dtype, + demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape + ), + demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape + ), + quant_tile_shape_x=x_quant_tile_shape, + quant_tile_shape_w=w_quant_tile_shape, + bias=bias, + out=out.clone() if accumulate else None, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM" + aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None + aux_tensor_ref = aux_tensor.clone() if use_gelu else None + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
+ y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + aux_tensor, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y are not the same tensor + assert y_ref is not y, "y_ref and y should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y = torch.where(y.isnan(), torch.zeros_like(y), y) + + if use_gelu: + # Check + if use_grad: + # With use_grad, GEMM should use aux tensor to calculate + # gradient + gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None) + # TODO: How do we decide whether this is acceptably close? + # Could also try to put the activation inside the reference + # before the output cast to see different tolerances. + torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2) + else: + # aux tensor is pre-gelu aux output. Verify against y_ref. 
+ torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol) + act = torch.nn.GELU() + gelu_ref = act(y_ref) + # gelu_ref = tex.gelu(y_ref, None) + torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol) + else: + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + +def cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED", + expected_err_cls=RuntimeError +): + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + 
force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = use_grad + gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + with pytest.raises(expected_err_cls, match=expected_err_msg): + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (128, 128, 128), + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 128, 336), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (320, 256, 336), + (1024, 4096, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], 
ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + (320, 256, 336), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + 
accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_bias( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_bias=True, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) 
+@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["colxrow", "colxcol", "rowxcol"], +) +def test_cublas_gemm_fp8_blockwise_columnwise( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + is_x_columnwise, + is_w_columnwise, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", 
"1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "use_grad", + [ + True, + ], + ids=["grad"], +) +def test_cublas_gemm_fp8_gelu( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad, +): + # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed + # so the epilogue is disabled on the transformer engine side. + if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled): + pytest.skip( + "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)." + ) + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_gelu=True, + use_grad=use_grad, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_split_accumulator_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", 
[torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_bgrad_not_supported( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: BGRAD epilogue is not supported for fp8. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=True, + use_bias=True, + expected_err_msg="Epilogue requested outside of the available", + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_gelu_unsupported_cases_error( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_bias, + use_grad, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + if 
use_grad and not use_bias and out_dtype == torch.bfloat16: + pytest.skip("DGELU epilogue is supported for bfloat16.") + elif use_grad and not use_bias: + expected_err = "an unsupported value or parameter was passed" + else: + expected_err = "Epilogue requested outside of the available" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=use_grad, + use_bias=use_bias, + use_gelu=True, + expected_err_msg=expected_err, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_illegal_dtype_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # e5m2 by e5m2 not supported. 
+ cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (False, False), + ], + ids=["2Dx2D"], +) +def test_illegal_2D_by_2D_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # 2D block quantization by 2D block quantization is not supported. + expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg=expected_err_msg, + ) + + +@pytest.mark.parametrize( + "M, K, N, legalX1d, legalX2d", + [ + # M dim unconstrained when X is 2D. 
+ (255, 128, 256, False, True), + # K must be multiple of 16 + (256, 120, 256, False, False), + # N must be a multiple of 8 + (256, 128, 252, False, False), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (False, True), + (True, True), + ], + ids=["1Dx2D", "2Dx1D", "1Dx1D"], +) +def test_unaligned_shapes( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + legalX1d, + legalX2d, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + legal = legalX1d if is_x_1d_scaled else legalX2d + if not legal: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg="dimension requirement", + ) + else: + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + "uniform", # noise type + 1.0, # x_magnitude + 1.0, # w_magnitude + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f19465c44b..6fe3539257 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -52,97 +52,173 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } +/* Parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. 
+ * + */ struct GemmParam { - void *A; - void *B; - cublasOperation_t transA; - cublasOperation_t transB; - transformer_engine::DType Atype; - transformer_engine::DType Btype; - void *A_scale_inv; - void *B_scale_inv; - int lda; - int ldb; - - GemmParam(cublasOperation_t transA, cublasOperation_t transB) - : A(nullptr), - B(nullptr), - transA(transA), - transB(transB), - Atype(transformer_engine::DType::kNumTypes), - Btype(transformer_engine::DType::kNumTypes), - A_scale_inv(nullptr), - B_scale_inv(nullptr), - lda(0), - ldb(0) {} + void *A = nullptr; + void *B = nullptr; + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + transformer_engine::DType Atype = transformer_engine::DType::kNumTypes; + transformer_engine::DType Btype = transformer_engine::DType::kNumTypes; + void *A_scale_inv = nullptr; + void *B_scale_inv = nullptr; + int lda = 0; // A column strides + int ldb = 0; // B column strides }; +/* Populate parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int k, const int lda, const int ldb) { + int m, int n, int k) { using namespace transformer_engine; - // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. - // Must either force them both into a common block scaling mode or loosen this - // restriction. 
- NVTE_CHECK(A.scaling_mode == B.scaling_mode, - "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK( + A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); - GemmParam ret(transA, transB); + GemmParam ret; + + // Device compute capability + const int arch = cuda::sm_arch(); - ret.lda = lda; - ret.ldb = ldb; + // Transpose mode with column-major ordering + bool transa_bool = transA == CUBLAS_OP_T; + bool transb_bool = transB == CUBLAS_OP_T; - // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases - // or need to be treated as `is_tensor_scaling`. + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.A = A.data.dptr; + ret.transA = transA; + ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - if (transA == CUBLAS_OP_T) { - ret.Atype = A.data.dtype; - } else { - ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; - if (is_fp8_dtype(ret.Atype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - ret.A = A.columnwise_data.dptr; - ret.transA = CUBLAS_OP_T; - ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.lda = k; - } + ret.lda = transa_bool ? k : m; + if (arch < 100 && !transa_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. 
+ if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } + } else if (is_mxfp_scaling(A.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = transA; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = m; + } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.lda % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. 
+ NVTE_CHECK((m % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } else { + NVTE_ERROR("A has unsupported scaling mode"); + } + + // Configure B matrix + if (is_tensor_scaling(B.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.B = B.data.dptr; + ret.transB = transB; + ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - if (transB == CUBLAS_OP_T) { - ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; - if (is_fp8_dtype(ret.Btype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - ret.B = B.columnwise_data.dptr; - ret.transB = CUBLAS_OP_N; - ret.B_scale_inv = B.columnwise_scale_inv.dptr; - ret.ldb = k; - } + ret.ldb = transb_bool ? n : k; + if (arch < 100 && transb_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } + } else if (is_mxfp_scaling(B.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = transB; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { - ret.Btype = B.data.dtype; + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { + // Observed this requirement only present for B tensor is 1D quantized. + NVTE_CHECK((n % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } } else { - // If not tensor scaling (which includes also high precision types), we need to - // use the proper version of data - // We leave the transA/B values as is, since Blackwell supports transposes - ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; - ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; - ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transB ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + NVTE_ERROR("B has unsupported scaling mode"); } + return ret; } @@ -153,18 +229,33 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, - int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, cudaStream_t stream) { + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + // Tensor dims in row-major order + const int A0 = inputA->flat_first_dim(); + const int A1 = inputA->flat_last_dim(); + const int B0 = inputB->flat_first_dim(); + const int B1 = inputB->flat_last_dim(); + + // GEMM dims in column-major order + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + NVTE_CHECK((transb == CUBLAS_OP_T ? 
B0 : B1) == k, + "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, + ")"); + const int ldd = m; + // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); + const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -226,6 +317,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, param.transA == CUBLAS_OP_N ? k : m, param.lda)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -249,12 +341,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized - // GEMM types. - // Scaling factors. 
#if CUDA_VERSION >= 12080 - cublasLtMatmulMatrixScale_t scaling_mode; + cublasLtMatmulMatrixScale_t scaling_mode_a; + cublasLtMatmulMatrixScale_t scaling_mode_b; #endif if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; @@ -266,8 +356,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); #if CUDA_VERSION >= 12080 - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; - } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -276,7 +367,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
if (cublasLtGetVersion() <= 120803) { @@ -285,7 +377,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#endif + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { +#if CUDA_VERSION >= 12090 + float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D + ? 
CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#else + NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); +#endif // CUDA_VERSION >= 12090 +#endif // CUDA_VERSION >= 12080 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); @@ -293,9 +410,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output @@ -305,8 +422,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. 
+ NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); #endif // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 @@ -364,6 +484,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) || + (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) { + NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || + epilogue == CUBLASLT_EPILOGUE_DGELU), + "Epilogue requested outside of the available and tested cuBLAS functionality for " + "float8 block scaled GEMM"); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); @@ -411,7 +539,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C @@ -469,35 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const size_t A0 = inputA->flat_first_dim(); - const size_t A1 = inputA->flat_last_dim(); - const size_t B0 = inputB->flat_first_dim(); - const size_t B1 = inputB->flat_last_dim(); - - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? 
B1 : B0; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -525,31 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, + inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index dae39d82bf..f6b6ae22c2 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 8519fe1b64..c56f9ef407 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor 
&x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; From 45519f104f30097bb458bee9b9b8c6b88350b323 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 16:19:54 -0700 Subject: [PATCH 29/53] Skip grouped linear sanity test. Signed-off-by: Keith Wyss --- tests/pytorch/test_sanity.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 03e8563975..1f9b33b677 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -548,6 +548,8 @@ def test_sanity_grouped_linear( pytest.skip("Grouped linear does not support MXFP8") if fp8_recipe.float8_current_scaling(): pytest.skip("Grouped linear does not support FP8 current scaling") + if fp8_recipe.float8_block_scaling(): + pytest.skip("Grouped linear does not support FP8 block scaling") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") From a21e65b262a0edcad6c6ff485289e6a4f275ea8f Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 16:17:51 -0700 Subject: [PATCH 30/53] Set usage before BF16 gather. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/layernorm_mlp.py | 1 + transformer_engine/pytorch/module/linear.py | 1 + 3 files changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 06bde8ab12..3b79182f4e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -737,6 +737,7 @@ def backward( ): # Async gather in BF16 does not asynchronously # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 50ec8ce490..6aa8109430 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -938,6 +938,7 @@ def backward( ): # Async gather in BF16 does not asynchronously # call quantizer after gather. + ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2cca226708..ac4725cb7c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -623,6 +623,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): # Async gather in BF16 does not asynchronously # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data From ba5dc5ddebec040939d52ba72efb70202e6fa25d Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Tue, 8 Apr 2025 06:36:11 -0700 Subject: [PATCH 31/53] Enable reuse of dummy wgrad tensor (#1651) * Use dummy wgrads for lower memory consumption Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Vasudevan Rengasamy * Bug fix to avoid sharing gradients. Signed-off-by: Vasudevan Rengasamy * Disable automatic use of batch_p2p_comm for CP2 Signed-off-by: Vasudevan Rengasamy * Change weight to origin_weight for LN_LINEAR Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Vasudevan Rengasamy Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 4 ++-- transformer_engine/pytorch/module/base.py | 17 +++++++++++++++++ .../pytorch/module/layernorm_linear.py | 18 ++++++++---------- transformer_engine/pytorch/module/linear.py | 18 ++++++++---------- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6440c628cd..0d442435bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -616,7 +616,7 @@ def forward( rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -1564,7 
+1564,7 @@ def backward(ctx, dout): rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..31a464caad 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -43,6 +43,7 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _multi_stream_cublas_workspace = [] +_dummy_wgrads = {} _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 @@ -78,6 +79,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: return _multi_stream_cublas_workspace +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + + def initialize_ub( shape: list, tp_size: int, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..f49bad48c3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -19,6 +19,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, 
_2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -796,18 +797,15 @@ def backward( if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..ca9dd29043 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -16,6 +16,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -688,18 +689,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None From 0e8d32465ab9d615c9a24fd70ce4a50078c8a33c Mon Sep 17 00:00:00 2001 From: zhongboz Date: Mon, 7 Apr 2025 21:39:47 -0700 Subject: [PATCH 32/53] refactor for 
nvte_quantize_v2 Signed-off-by: zhongboz --- tests/pytorch/test_float8tensor.py | 27 +++++++++++++++- .../common/activation/activation_template.h | 8 ++--- transformer_engine/common/common.h | 4 ++- .../common/include/transformer_engine/cast.h | 16 +++++++++- .../transformer_engine/transformer_engine.h | 8 +++++ transformer_engine/common/recipe/__init__.py | 19 +++++++---- .../common/transformer_engine.cpp | 6 ++++ transformer_engine/common/util/cast.cu | 32 +++++++++++++------ .../common/util/cast_kernels.cuh | 21 ++++++++---- transformer_engine/pytorch/csrc/common.h | 2 -- .../pytorch/csrc/extensions/activation.cpp | 6 +++- .../pytorch/csrc/extensions/cast.cpp | 13 ++++++-- .../pytorch/csrc/extensions/normalization.cpp | 18 ++++++++--- .../pytorch/csrc/extensions/quantizer.cpp | 8 ++--- .../pytorch/csrc/extensions/transpose.cpp | 1 + 15 files changed, 145 insertions(+), 44 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 42600e3099..cea5682b38 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import io -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytest import torch @@ -124,6 +124,7 @@ def _test_quantize_dequantize( scale: float = 3.5, dtype: torch.dtype = torch.float32, dims: DimsType = 23, + noop_flag: Optional[torch.Tensor] = None, ) -> None: """Check numerical error when casting to FP8 and back""" @@ -132,6 +133,17 @@ def _test_quantize_dequantize( # Cast to FP8 and back x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + if noop_flag is not None: + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_flag) + if noop_flag.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, 
rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) + return + x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -158,6 +170,19 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("noop", [True, False]) + def test_quantize_dequantize_noop(self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + noop: bool + ) -> None: + noop_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + if noop: + noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") + self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype, noop_flag=noop_tensor) + def test_basic_ops( self, dims: DimsType = 23, diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 708403f911..67de1380ee 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..3851915387 100644 --- 
a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -233,10 +233,12 @@ struct Tensor { struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; + NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { sizeof(bool), // force_pow_2_scales - sizeof(float) // amax_epsilon + sizeof(float), // amax_epsilon + sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7fa7957fa4..7c39ba896e 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,7 +89,7 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. @@ -102,6 +102,20 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output quantized tensor. + * \param[out] noop Noop tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ba47b9d38c..2219442517 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -286,6 +286,8 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, + /*! Noop tensor, with noop tensor element value = 1, quantization kernel will early exit */ + kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; @@ -724,6 +726,12 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } + /*! \brief Set noop tensor pointer */ + void set_noop_tensor(NVTETensor noop_tensor) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, + &noop_tensor, sizeof(NVTETensor)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9bf054c55a..9eea7e49ec 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -5,6 +5,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations import warnings +import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -314,11 +315,11 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} used for quantization of input tensor x - fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} used for quantization of weight tensor w - fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} used for quantization of gradient tensor dY x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) qblock scaling for x. @@ -346,12 +347,18 @@ class Float8BlockScaling(Recipe): `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + + Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. + To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. 
+ export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 """ + use_e8_scales: bool = not (os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1") + fp8_format: Format = Format.E4M3 - fp8_quant_fwd_inp = QParams(power_2_scale=True, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=True, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=True, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) x_block_scaling_dim: int = 1 w_block_scaling_dim: int = 2 grad_block_scaling_dim: int = 1 diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 97df5892b6..706e0bd0b5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -429,6 +429,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(buf, &config_.amax_epsilon, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(buf, &config_.noop_tensor, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -458,6 +461,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(&config_.amax_epsilon, buf, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(&config_.noop_tensor, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 22a50025df..70e732433c 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 
@@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, + dbias, workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no NVTE_API_CALL(nvte_quantize_noop); using namespace transformer_engine; + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = false; constexpr bool IS_ACT = false; @@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, noop, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, + dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ -66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +92,7 @@ void 
nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +148,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh 
b/transformer_engine/common/util/cast_kernels.cuh index a83b5e7063..2dc80582eb 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,8 +1215,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, +void quantize_helper(const NVTETensor input, const NVTETensor grad, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config_, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; @@ -1232,8 +1233,14 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe auto output_tensor = reinterpret_cast(output); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); + + const QuantizationConfig * quant_config = reinterpret_cast(quant_config_); + + // extract noop tensor from quant_config if it's not null + const NVTETensor noop = quant_config ? quant_config->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); + switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (output_tensor->has_columnwise_data()) { @@ -1263,11 +1270,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config ? quant_config->force_pow_2_scales : false; + float epsilon = quant_config ? 
quant_config->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, + epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1275,7 +1283,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config ? quant_config->force_pow_2_scales : false; + float epsilon = quant_config ? quant_config->amax_epsilon : 0.0f; FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() ? FP8BlockwiseRowwiseOption::ROWWISE : FP8BlockwiseRowwiseOption::NONE; @@ -1285,7 +1294,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, rowwise_option, columnwise_option, force_pow_2_scales, stream); + epsilon, rowwise_option, columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..9ca989944e 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,8 +167,6 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. 
bool force_pow_2_scales = false; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 1ef6f5258d..8554ae7de1 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,11 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + // sanity check, since activation fusion is not supported for blockwise quantization yet + // need to raise an error here instead of silently going into act_func with wrong numerics + NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); } else { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c3ccff154..50ec9ae416 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob if (te_output.numel() == 0) return out; + QuantizationConfigWrapper quant_config; + quant_config.set_noop_tensor(te_noop.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ 
-61,14 +64,20 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; + // this config is used for cs scaling factor computation + // because compute scale is cannot be fused with quantize kernel + // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + }else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), + nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index cbdeee5b48..b714d9b93f 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,6 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ 
-166,14 +167,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, + else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + } + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); } @@ -293,6 +298,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -309,14 +315,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax 
updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, + else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + } + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ac6292e53..fbf31a7f5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -257,12 +257,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires pow2 scales"); - NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires amax_epsilon==0"); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index a873586032..4533f02cf7 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ 
-69,6 +69,7 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + // TODO: switch to nvte_quantize_v2 with advanced numerical options nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], at::cuda::getCurrentCUDAStream()); } From e07760127d824a6a281b242d032e1f52eb960612 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 09:06:36 -0700 Subject: [PATCH 33/53] Format code. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8tensor.py | 12 +++++----- .../common/activation/activation_template.h | 8 +++---- transformer_engine/common/common.h | 4 ++-- .../common/include/transformer_engine/cast.h | 4 ++-- .../transformer_engine/transformer_engine.h | 6 ++--- transformer_engine/common/recipe/__init__.py | 4 ++-- transformer_engine/common/util/cast.cu | 12 +++++----- .../common/util/cast_kernels.cuh | 22 +++++++++---------- .../pytorch/csrc/extensions/activation.cpp | 3 ++- .../pytorch/csrc/extensions/cast.cpp | 6 ++--- .../pytorch/csrc/extensions/normalization.cpp | 14 +++++------- .../pytorch/csrc/extensions/transpose.cpp | 2 +- 12 files changed, 46 insertions(+), 51 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index cea5682b38..dabe3e06f9 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -135,15 +135,15 @@ def _test_quantize_dequantize( x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) if noop_flag is not None: # if noop, then when we input a different tensor, output should still be x_fp8_orig - x_ref_noop_test = 2 * x_ref.cuda() + x_ref_noop_test = 2 * x_ref.cuda() x_fp8_orig = x_fp8.clone() x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_flag) if noop_flag.item() == 1.0: torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) else: torch.testing.assert_close(x_fp8, x_ref_noop_test, 
**_tols[fp8_dtype]) - return - + return + x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -173,10 +173,8 @@ def test_quantize_dequantize_dims(self, dims: DimsType) -> None: @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("noop", [True, False]) - def test_quantize_dequantize_noop(self, - fp8_dtype: tex.DType, - dtype: torch.dtype, - noop: bool + def test_quantize_dequantize_noop( + self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool ) -> None: noop_tensor = torch.empty(1, dtype=torch.float32, device="cuda") if noop: diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67de1380ee..67f173a4ab 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, output, dbias, - workspace, nullptr, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, - workspace, nullptr, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 3851915387..728f8ad147 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -236,8 +236,8 @@ struct QuantizationConfig { NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon + 
sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7c39ba896e..8b45e87de5 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -113,8 +113,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no * \param[in] quant_config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ -void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, - cudaStream_t stream); +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 2219442517..05a15f700a 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -726,10 +726,10 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } - /*! \brief Set noop tensor pointer */ + /*! 
\brief Set noop tensor pointer */ void set_noop_tensor(NVTETensor noop_tensor) { - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, - &noop_tensor, sizeof(NVTETensor)); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor, + sizeof(NVTETensor)); } private: diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9eea7e49ec..85df533290 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -348,8 +348,8 @@ class Float8BlockScaling(Recipe): When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. - To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. + Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. + To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. 
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 """ diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 70e732433c..1f146c7a33 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, output, - dbias, workspace, nullptr, stream); + detail::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -51,8 +51,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); } -void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, - cudaStream_t stream) { +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_v2); using namespace transformer_engine; @@ -63,8 +63,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuant constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, output, - dbias, workspace, quant_config, stream); + detail::quantize_helper( + input, grad, output, dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 2dc80582eb..d376d9af6d 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,10 +1215,9 @@ namespace detail { template -void 
quantize_helper(const NVTETensor input, const NVTETensor grad, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config_, - cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config_, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1234,13 +1233,13 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); - const QuantizationConfig * quant_config = reinterpret_cast(quant_config_); + const QuantizationConfig *quant_config = + reinterpret_cast(quant_config_); // extract noop tensor from quant_config if it's not null const NVTETensor noop = quant_config ? quant_config->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); - switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (output_tensor->has_columnwise_data()) { @@ -1274,8 +1273,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, float epsilon = quant_config ? quant_config->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - epsilon, + output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1291,10 +1289,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, FP8BlockwiseColumnwiseOption columnwise_option = output_tensor->has_columnwise_data() ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE : FP8BlockwiseColumnwiseOption::NONE; - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - epsilon, rowwise_option, columnwise_option, force_pow_2_scales, stream); + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 8554ae7de1..bf037fe931 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { // sanity check, since activation fusion is not supported for blockwise quantization yet // need to raise an error here instead of silently going into act_func with wrong numerics diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50ec9ae416..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp 
@@ -64,7 +64,7 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - // this config is used for cs scaling factor computation + // this config is used for cs scaling factor computation // because compute scale is cannot be fused with quantize kernel // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); @@ -72,13 +72,13 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - }else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = static_cast(my_quantizer.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index b714d9b93f..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -172,14 +172,13 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, 
DType::kFloat32, out_cu.defaultShape); - } - else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -320,14 +319,13 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); - } - else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 4533f02cf7..e12990f79c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -69,7 +69,7 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { for (size_t i = 0; i < 
nvte_tensor_output_list.size(); i++) { - // TODO: switch to nvte_quantize_v2 with advanced numerical options + // TODO: switch to nvte_quantize_v2 with advanced numerical options nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], at::cuda::getCurrentCUDAStream()); } From 07a70b897f6c06e492380ef81d383f36c7b2d335 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 09:11:38 -0700 Subject: [PATCH 34/53] Cleanup nvte_quantize_v2 Signed-off-by: Keith Wyss --- tests/pytorch/test_float8tensor.py | 2 +- .../common/include/transformer_engine/cast.h | 3 +-- .../transformer_engine/transformer_engine.h | 6 +++++- transformer_engine/common/recipe/__init__.py | 9 +++++---- .../common/util/cast_kernels.cuh | 18 +++++++++--------- transformer_engine/pytorch/csrc/common.h | 6 ++++++ .../pytorch/csrc/extensions/cast.cpp | 3 +-- .../pytorch/csrc/extensions/normalization.cpp | 6 ++---- .../pytorch/csrc/extensions/quantizer.cpp | 5 +++++ 9 files changed, 35 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index dabe3e06f9..2db45e9e63 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -176,7 +176,7 @@ def test_quantize_dequantize_dims(self, dims: DimsType) -> None: def test_quantize_dequantize_noop( self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool ) -> None: - noop_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") if noop: noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype, noop_flag=noop_tensor) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 8b45e87de5..693f2e00c5 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ 
b/transformer_engine/common/include/transformer_engine/cast.h @@ -103,13 +103,12 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel - * based on the value of the 'noop' tensor. + * by configuring a noop in quant_config. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. * * \param[in] input Input tensor to be cast. * \param[in,out] output Output quantized tensor. - * \param[out] noop Noop tensor. * \param[in] quant_config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 05a15f700a..d25eb4d929 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -286,7 +286,11 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, - /*! Noop tensor, with noop tensor element value = 1, quantization kernel will early exit */ + /*! Noop tensor (containing a scalar). + If the scalar element value = 1, quantization kernel will early exit. + This is a tensor in order that the flag can be on GPU and conditional + early exit is compatible with a static CUDA graph. 
+ */ kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 85df533290..3d1f8c2ad6 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -351,14 +351,15 @@ class Float8BlockScaling(Recipe): Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code. """ - use_e8_scales: bool = not (os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1") + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" fp8_format: Format = Format.E4M3 - fp8_quant_fwd_inp = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=use_e8_scales, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) x_block_scaling_dim: int = 1 w_block_scaling_dim: int = 2 grad_block_scaling_dim: int = 1 diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index d376d9af6d..a599d88530 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1217,7 +1217,7 @@ template void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config_, cudaStream_t stream) { + const NVTEQuantizationConfig 
quant_config, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1233,11 +1233,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); - const QuantizationConfig *quant_config = - reinterpret_cast(quant_config_); + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); - // extract noop tensor from quant_config if it's not null - const NVTETensor noop = quant_config ? quant_config->noop_tensor : nullptr; + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); switch (output_tensor->scaling_mode) { @@ -1269,8 +1269,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config ? quant_config->force_pow_2_scales : false; - float epsilon = quant_config ? quant_config->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, @@ -1281,8 +1281,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. 
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config ? quant_config->force_pow_2_scales : false; - float epsilon = quant_config ? quant_config->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() ? FP8BlockwiseRowwiseOption::ROWWISE : FP8BlockwiseRowwiseOption::NONE; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9ca989944e..cd08d1243f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,11 +167,14 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; + + private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + int block_scaling_dim = 2; public: @@ -185,6 +188,9 @@ class Float8BlockQuantizer : public Quantizer { // Gets rowwise and columnwise_data from tensor and sets them on wrapper void set_quantization_params(TensorWrapper* tensor) const override; + // Set options for quantization on QuantizationConfigWrapper. + void set_quantization_config(QuantizationConfigWrapper* quant_config) const; + // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 84e50dea22..8323c7f7ee 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -74,8 +74,7 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + my_quantizer_bw->set_quantization_config(&quant_config); } nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index dae6ce42e2..900586154c 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -174,8 +174,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + my_quantizer_bw->set_quantization_config(&quant_config); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); @@ -321,8 +320,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = 
static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + my_quantizer_bw->set_quantization_config(&quant_config); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index fbf31a7f5b..c54042d7f8 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -278,6 +278,11 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const columnwise_data.shape); } +void Float8BlockQuantizer::set_quantization_config(QuantizationConfigWrapper* quant_config) const { + quant_config->set_force_pow_2_scales(this->force_pow_2_scales); + quant_config->set_amax_epsilon(this->amax_epsilon); +} + std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data) const { using namespace pybind11::literals; From 64f2601e0e8fb1b748af4252fa32d2183272a0de Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 13:35:55 -0700 Subject: [PATCH 35/53] Test fp32 scales. 
Signed-off-by: Keith Wyss --- .../test_float8_blockwise_scaling_exact.py | 108 +++++++++++++----- 1 file changed, 79 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index d96d482ce9..0baee4975d 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -88,35 +88,7 @@ def initialize_for_many_scales( return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles. - (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] -) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) -def test_quantization_block_tiling_versus_reference( +def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, @@ -221,12 +193,90 @@ def test_quantization_block_tiling_versus_reference( [ # full tile cases (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. 
+ (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (256, 256), + (2048, 1024), + # Padding required cases + (256, 272), + (303, 300), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference_fp32_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, 
pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( From 3cb712c7c81591a620e0be6305966c8c4712e427 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 14:11:47 -0700 Subject: [PATCH 36/53] Disable CUDA graph. Signed-off-by: Keith Wyss --- tests/pytorch/test_cuda_graphs.py | 4 ++- tests/pytorch/test_sanity.py | 31 +++++++++++++++++++ .../pytorch/ops/basic/basic_linear.py | 10 ++---- transformer_engine/pytorch/ops/op.py | 2 ++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5f896aaafe..517cb11761 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -55,7 +55,9 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), - recipe.Float8BlockScaling(), + # TODO: Support Float8BlockScaling with CUDA graph. + # One known issue is make_quantizers/num_quantizers, but + # sequential also should have changes. ] # Supported data types diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 1f9b33b677..d8552d63a4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,6 +46,9 @@ # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -440,6 +443,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -471,6 +476,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -503,6 +510,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -544,6 +553,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8(): pytest.skip("Grouped linear does not support MXFP8") if fp8_recipe.float8_current_scaling(): @@ -593,6 +604,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if 
fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -643,6 +656,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -710,6 +725,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -769,6 +786,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -826,6 +845,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -861,6 +882,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + 
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -899,6 +922,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -940,6 +965,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -982,8 +1009,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling(): + pytest.skip("cuda graph not supported for float8_block_scaling recipe") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index b0f67b4cb3..3d57bbeba7 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -255,16 +255,10 @@ def _canonicalize_tensor_parallelism( ) def num_quantizers(self, mode: str) -> int: - # 
Adhere to consistent conventions of the non-fused - # module code, where fwd quantized tensors are (x, w, output) - # and bwd quantized tensors are (ygrad, dgrad) - # - # Fused code need not use the output or dgrad quantizers. - # Since fp8_output flag is not available. if mode == "forward": - return 3 - if mode == "backward": return 2 + if mode == "backward": + return 1 return 0 def reset_parameters(self) -> None: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 4c90ce28e3..60428e683c 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -219,6 +219,8 @@ def _reset_quantization_recipe_state( if num_quantizers == 0: continue + if recipe.float8_block_scaling(): + raise NotImplementedError("CUDA graph support for float8_block_scaling pending.") # Construct quantization recipe state recipe_state = RecipeState.create( recipe, From 9d4e11eaa508383e35b510dc338e58b09c30be73 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 8 Apr 2025 14:12:05 -0700 Subject: [PATCH 37/53] [PyTorch] Debug GEMM refactor (#1652) * Minor stylistic tweaks and typo fixes Review suggestions from @ptrendx Signed-off-by: Tim Moon * Fix incorrect col strides for MXFP8 matrices Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/gemm/cublaslt_gemm.cu | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6fe3539257..0cd0762ee5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, 
const cubla const int arch = cuda::sm_arch(); // Transpose mode with column-major ordering - bool transa_bool = transA == CUBLAS_OP_T; - bool transb_bool = transB == CUBLAS_OP_T; + bool is_A_transposed = transA == CUBLAS_OP_T; + bool is_B_transposed = transB == CUBLAS_OP_T; // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = transA; ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - ret.lda = transa_bool ? k : m; - if (arch < 100 && !transa_bool) { + ret.lda = is_A_transposed ? k : m; + if (arch < 100 && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = transA; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.lda = m; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? 
k : m; } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage @@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transB = transB; ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - ret.ldb = transb_bool ? n : k; - if (arch < 100 && transb_bool) { + ret.ldb = is_B_transposed ? n : k; + if (arch < 100 && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). 
- if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = transB; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; - ret.ldb = k; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = is_B_transposed ? n : k; } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = CUBLAS_OP_N; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; // Requirements from @@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; From 07a563b5d5e2838eda266a78cc75deeeb4455697 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 14:48:24 -0700 Subject: [PATCH 38/53] Simplify layernorm linear Signed-off-by: Keith Wyss --- .../pytorch/module/layernorm_linear.py | 47 +++---------------- 1 file changed, 6 insertions(+), 41 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 23fb845ac3..078ac71dde 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,11 +140,6 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - # Avoid quantized norm kernel if norm output will be returned - with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered - ) - tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output @@ -178,43 +173,13 @@ def forward( columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - # Configure quantizer for normalization output - with_quantized_norm = fp8 and not return_layernorm_output - # for Float8CurrentScalingQuantizer, 
layernorm/rmsnorm has not been fused with quantizer - # so we need to set with_quantized_norm to False - if isinstance(input_quantizer, Float8CurrentScalingQuantizer): - with_quantized_norm = False - if isinstance(input_quantizer, Float8BlockQuantizer): - # Quantizer has not been fused with norm yet. - with_quantized_norm = False - if with_quantized_norm: - if with_input_all_gather: - if isinstance(input_quantizer, MXFP8Quantizer): - with_quantized_norm = False - - # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` - if ( + # Avoid quantized norm kernel if norm output will be returned + with_quantized_norm = ( fp8 - and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() - and ub_bulk_dgrad - ): - input_quantizer.set_usage(rowwise=True, columnwise=False) - - ub_obj_fprop = None - ln_out = None - # For DelayScaling, output of normalization will be in fp8. - # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. - if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer): - ub_obj_fprop = get_ub(ub_name + "_fprop") - ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) - elif with_quantized_norm: - if with_input_all_gather: - input_quantizer.set_usage(rowwise=True, columnwise=False) - ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") - else: - ln_out = torch.empty_like( - inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" - ) + and not return_layernorm_output + and not return_layernorm_output_gathered + and not isinstance(input_quantizer, Float8BlockQuantizer) + ) # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") From 9a3abe2439c5071eb6e1bb6b6282ec7b70271319 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 15:21:54 -0700 Subject: [PATCH 39/53] Cleanup layernorm linear. 
Signed-off-by: Keith Wyss --- .../pytorch/module/layernorm_linear.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 078ac71dde..2989dcaeb2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -159,6 +159,8 @@ def forward( " current scaling" ) + # FP8 allgather not implemented for Float8_Blockwise_Tensor + force_bf16_all_gather = with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) # Configure quantizer for norm output if fp8: if input_quantizer is None: @@ -174,11 +176,12 @@ def forward( input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) # Avoid quantized norm kernel if norm output will be returned + force_bf16_blockwise_gather = (fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)) with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered - and not isinstance(input_quantizer, Float8BlockQuantizer) + and not force_bf16_blockwise_gather ) # Apply normalization @@ -206,8 +209,6 @@ def forward( ln_out_total = None ub_obj_fprop = None if with_input_all_gather: - # TODO(kwyss): Support FP8 allgather for FP8 block quantization. 
- force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered # norm output will be returned @@ -219,7 +220,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not (with_quantized_norm or force_high_precision_gather): + if not with_quantized_norm and not force_bf16_blockwise_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -335,6 +336,10 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # For force_bf16_blockwise_gather, we should + # be saving the unquantized ln_out to ctx. + assert not force_bf16_blockwise_gather + # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) From 27d99223fe5b0c2abb40347b7e15ad0d255617f3 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 15:43:15 -0700 Subject: [PATCH 40/53] LayerNorm linear bwd gather logic. 
Signed-off-by: Keith Wyss --- .../pytorch/module/layernorm_linear.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2989dcaeb2..63a5cab86f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -159,8 +159,6 @@ def forward( " current scaling" ) - # FP8 allgather not implemented for Float8_Blockwise_Tensor - force_bf16_all_gather = with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) # Configure quantizer for norm output if fp8: if input_quantizer is None: @@ -176,12 +174,14 @@ def forward( input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) # Avoid quantized norm kernel if norm output will be returned - force_bf16_blockwise_gather = (fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)) + force_hp_blockwise_ln_out_gather = ( + fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) + ) with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered - and not force_bf16_blockwise_gather + and not force_hp_blockwise_ln_out_gather ) # Apply normalization @@ -220,7 +220,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm and not force_bf16_blockwise_gather: + if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -326,6 +326,7 @@ def forward( ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. 
if backward_needs_input: @@ -336,9 +337,9 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) - # For force_bf16_blockwise_gather, we should + # For force_hp_blockwise_ln_out_gather, we should # be saving the unquantized ln_out to ctx. - assert not force_bf16_blockwise_gather + assert not force_hp_blockwise_ln_out_gather # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(weightmat, QuantizedTensor): @@ -618,11 +619,14 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + # async_op is not compatible with high precision gather since + # gather_along_first_dim does not offer callback chaining. + gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -706,7 +710,7 @@ def backward( if ctx.input_quantizer is not None and not isinstance( ln_out_total, QuantizedTensor ): - # Async gather in BF16 does not asynchronously + # Async gather may have been done in BF16 # call quantizer after gather. ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) From b62d555766cb9667b6b6f49604596de290e6b589 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 16:17:45 -0700 Subject: [PATCH 41/53] Communication updates. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 22 +++++------ .../pytorch/module/layernorm_linear.py | 15 ++++---- .../pytorch/module/layernorm_mlp.py | 37 ++++++++++++------- transformer_engine/pytorch/module/linear.py | 28 ++++++++------ 4 files changed, 58 insertions(+), 44 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e8671bc278..de92f7f7d1 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -953,11 +953,15 @@ def _all_gather_fp8_blockwise( all gather is done asynchronously, but quantization is deferred until after the gather, this returns the full precision tensor. - NOTE: It would be preferable to always honor the quantizer and quantize - the result, and this may be possible via calling `get_future()` on the + NOTE: The implementation is not sophisticated enough to honor async_op=True + and also apply the quantizer if quantizer=None. In such a case, it falls + back to a synchronous gather and invokes the quantizer. + A more sophisticated approach may be possible via calling `get_future()` on the asynchronous handler, calling `make_empty()` on the quantizer, and chaining a callback with `then()` to perform `update_quantized`. This invites complications and also requires pre-allocating the quantized tensor. + Or other callbacks are possible if the type can be relaxed from torch.distributed.Work + to a duck typed done check. """ # Input tensor attributes @@ -1000,13 +1004,9 @@ def _all_gather_fp8_blockwise( device=device, memory_format=torch.contiguous_format, ) - handle = torch.distributed.all_gather_into_tensor( - out, inp, group=process_group, async_op=async_op - ) - # NOTE: if async, out will not be quantized. 
- if handle is None: - out = quantizer(out) - return out, handle + torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + out = quantizer(out) + return out, None # Implementation of fp8 gather needs to account for: # * Getting columnwise data as a transpose of how it is stored for GEMMS. # * Gathering non GEMM swizzled scales. @@ -1155,10 +1155,6 @@ def gather_along_first_dim( ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. - - NOTE: Caller should be aware that there are asynchronous cases - where quantizer is not None, but the output will not be quantized. - This affects Float8BlockQuantizer. """ # Return immediately if no communication is required diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 63a5cab86f..ff5279c83c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -174,6 +174,7 @@ def forward( input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. force_hp_blockwise_ln_out_gather = ( fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) ) @@ -707,13 +708,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.input_quantizer is not None and not isinstance( - ln_out_total, QuantizedTensor - ): - # Async gather may have been done in BF16 - # call quantizer after gather. - ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.input_quantizer(ln_out_total) + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather may have been done in BF16 + # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6aa8109430..d9a4058d4d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -225,6 +225,11 @@ def forward( ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + # TODO(kwyss): Support FP8 allgather of Float8Block quantization. + force_hp_fc1_input_gather = ( + fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) + ) + # Configure quantizer for norm output if fp8: if fc1_input_quantizer is None: @@ -260,20 +265,19 @@ def forward( ln_out_total = None ub_obj_lnout = None if sequence_parallel: - # TODO(kwyss): Support FP8 allgather of Float8Block quantization. 
- force_high_precision_gather = isinstance(fc1_input_quantizer, Float8BlockQuantizer) if return_layernorm_output_gathered: # Perform all-gather in high precision if gathered # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8: - ln_out = fc1_input_quantizer(ln_out) + if not force_hp_fc1_input_gather: + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: if fp8: - if not (with_quantized_norm or force_high_precision_gather): + if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -289,7 +293,10 @@ def forward( quantizer=(fc1_input_quantizer if fp8 else None), ) else: - if fp8 and not with_quantized_norm: + # NOTE: force_hp_fc1_input_gather is redundant with else, but + # here for clarity. We should not quantize ln_out if bwd needs + # to gather in hp. 
+ if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out @@ -475,6 +482,8 @@ def forward( if not return_layernorm_output: clear_tensor_data(ln_out) ln_out = None + elif force_hp_fc1_input_gather: + assert not isinstance(ln_out, QuantizedTensor) if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None @@ -503,6 +512,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -710,11 +720,12 @@ def backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) + gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) else: ln_out_total = ln_out @@ -933,13 +944,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fc1_input_quantizer is not None and not isinstance( - ln_out_total, QuantizedTensor - ): - # Async gather in BF16 does not asynchronously - # call quantizer after gather. - ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.fc1_input_quantizer(ln_out_total) + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 85326c4e02..7f223b3973 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,6 +131,10 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. + force_hp_input_gather = ( + fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) + ) if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -144,9 +148,7 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - # TODO(kwyss): Support FP8 allgather for FP8 block quantization. - force_high_precision_gather = isinstance(input_quantizer, Float8BlockQuantizer) - if force_high_precision_gather: + if force_hp_input_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) inputmat_total, _ = gather_along_first_dim( inputmat, tp_group, quantizer=input_quantizer @@ -286,6 +288,8 @@ def forward( # can be allgathered. if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if force_hp_input_gather: + assert not isinstance(inputmat, QuantizedTensor) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. 
@@ -334,6 +338,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -529,11 +534,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -619,13 +625,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.input_quantizer is not None and not isinstance( - inputmat_total, QuantizedTensor - ): - # Async gather in BF16 does not asynchronously - # call quantizer after gather. - ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = ctx.input_quantizer(inputmat_total) + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): From 196cd6d984541cb650a87df9c6b9effa92635eca Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Tue, 8 Apr 2025 16:39:25 -0700 Subject: [PATCH 42/53] Update transformer_engine/pytorch/ops/op.py Apply MR comment change. Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: kwyss-nvidia --- transformer_engine/pytorch/ops/op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 60428e683c..ca7f89d29e 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -220,7 +220,8 @@ def _reset_quantization_recipe_state( continue if recipe.float8_block_scaling(): - raise NotImplementedError("CUDA graph support for float8_block_scaling pending.") + raise NotImplementedError("Fusible operations do not support FP8 block scaling recipe") + # Construct quantization recipe state recipe_state = RecipeState.create( recipe, From 67e790bbd5408044df8ad4f26ce4c9385c2712bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 23:40:04 +0000 Subject: [PATCH 43/53] Lint fix. 
Signed-off-by: Keith Wyss --- transformer_engine/pytorch/distributed.py | 2 +- transformer_engine/pytorch/ops/op.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index de92f7f7d1..2aaee60307 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -942,7 +942,7 @@ def _all_gather_fp8_blockwise( inp: torch.Tensor, process_group: dist_group_type, *, - async_op: bool = False, + async_op: bool = False, # pylint: disable=unused-argument quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index ca7f89d29e..879a5794ef 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -220,7 +220,9 @@ def _reset_quantization_recipe_state( continue if recipe.float8_block_scaling(): - raise NotImplementedError("Fusible operations do not support FP8 block scaling recipe") + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) # Construct quantization recipe state recipe_state = RecipeState.create( From ea9e46bc14748160c35536ee6120cffe3361f608 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 17:31:17 -0700 Subject: [PATCH 44/53] MR feedback. 
Signed-off-by: Keith Wyss --- tests/pytorch/test_float8tensor.py | 29 ++++++++++--------- .../common/include/transformer_engine/cast.h | 5 +--- .../transformer_engine/transformer_engine.h | 4 +-- transformer_engine/common/recipe/__init__.py | 15 ---------- transformer_engine/pytorch/csrc/common.h | 5 +--- .../pytorch/csrc/extensions/cast.cpp | 3 +- .../pytorch/csrc/extensions/normalization.cpp | 6 ++-- .../pytorch/csrc/extensions/quantizer.cpp | 5 ---- transformer_engine/pytorch/distributed.py | 19 ++++-------- .../pytorch/module/layernorm_linear.py | 3 +- .../pytorch/module/layernorm_mlp.py | 4 +-- transformer_engine/pytorch/module/linear.py | 6 ++-- 12 files changed, 38 insertions(+), 66 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 2db45e9e63..d36da704b0 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -124,7 +124,6 @@ def _test_quantize_dequantize( scale: float = 3.5, dtype: torch.dtype = torch.float32, dims: DimsType = 23, - noop_flag: Optional[torch.Tensor] = None, ) -> None: """Check numerical error when casting to FP8 and back""" @@ -133,17 +132,6 @@ def _test_quantize_dequantize( # Cast to FP8 and back x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) - if noop_flag is not None: - # if noop, then when we input a different tensor, output should still be x_fp8_orig - x_ref_noop_test = 2 * x_ref.cuda() - x_fp8_orig = x_fp8.clone() - x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_flag) - if noop_flag.item() == 1.0: - torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) - else: - torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) - return - x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -179,7 +167,22 @@ def test_quantize_dequantize_noop( noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") if noop: noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") - 
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype, noop_flag=noop_tensor) + dims = 23 + scale: float = 3.5 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor) + if noop_tensor.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) def test_basic_ops( self, diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 693f2e00c5..64136b2c43 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -102,10 +102,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel - * by configuring a noop in quant_config. - * The type of quantized tensor in the output depends on the scaling mode of the output - * tensor. +/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options. * * \param[in] input Input tensor to be cast. * \param[in,out] output Output quantized tensor. 
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d25eb4d929..d3ee446f83 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -288,8 +288,8 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigAmaxEpsilon = 1, /*! Noop tensor (containing a scalar). If the scalar element value = 1, quantization kernel will early exit. - This is a tensor in order that the flag can be on GPU and conditional - early exit is compatible with a static CUDA graph. + This is a tensor because the flag must be on GPU in order to enable + conditional early even when captured in a static CUDA graph. */ kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 3d1f8c2ad6..8a9bf69c3a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -333,21 +333,6 @@ class Float8BlockScaling(Recipe): use for calculating dgrad in backward pass fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True use for calculating dgrad in backward pass - fp8_dpa: bool, default = `False` - Whether to enable FP8 dot product attention (DPA). When the model is placed in an - `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the - inputs from higher precision to FP8, performs attention in FP8, and casts tensors - back to higher precision as outputs. FP8 DPA currently is only supported in the - `FusedAttention` backend. - fp8_mha: bool, default = `False` - Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting - operations mentioned above at the DPA boundaries. Currently only standard MHA modules - i.e. 
`LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When - `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as - `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. - When `fp8_mha = True, fp8_dpa = True`, it becomes - `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index cd08d1243f..2d02a7d920 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -168,13 +168,13 @@ class Float8BlockQuantizer : public Quantizer { // Which float8 type is used for q data. DType dtype; - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + private: int block_scaling_dim = 2; public: @@ -188,9 +188,6 @@ class Float8BlockQuantizer : public Quantizer { // Gets rowwise and columnwise_data from tensor and sets them on wrapper void set_quantization_params(TensorWrapper* tensor) const override; - // Set options for quantization on QuantizationConfigWrapper. - void set_quantization_config(QuantizationConfigWrapper* quant_config) const; - // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 8323c7f7ee..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -74,7 +74,8 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = static_cast(my_quantizer.get()); - my_quantizer_bw->set_quantization_config(&quant_config); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 900586154c..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -174,7 +174,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = static_cast(my_quantizer.get()); - my_quantizer_bw->set_quantization_config(&quant_config); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); @@ -320,7 +321,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { auto my_quantizer_bw = 
static_cast(my_quantizer.get()); - my_quantizer_bw->set_quantization_config(&quant_config); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index c54042d7f8..fbf31a7f5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -278,11 +278,6 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const columnwise_data.shape); } -void Float8BlockQuantizer::set_quantization_config(QuantizationConfigWrapper* quant_config) const { - quant_config->set_force_pow_2_scales(this->force_pow_2_scales); - quant_config->set_amax_epsilon(this->amax_epsilon); -} - std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data) const { using namespace pybind11::literals; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 2aaee60307..7a1fde164b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -942,26 +942,17 @@ def _all_gather_fp8_blockwise( inp: torch.Tensor, process_group: dist_group_type, *, - async_op: bool = False, # pylint: disable=unused-argument + async_op: bool = False, # pylint: disable=unused-argument quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather FP8 tensor along first dimension for blockwise quantization. - Usually returns a Float8BlockwiseQTensorBase. 
In the case that the - all gather is done asynchronously, but quantization is deferred until - after the gather, this returns the full precision tensor. - - NOTE: The implementation is not sophisticated enough to honor async_op=True - and also apply the quantizer if quantizer=None. In such a case, it falls - back to a synchronous gather and invokes the quantizer. - A more sophisticated approach may be possible via calling `get_future()` on the - asynchronous handler, calling `make_empty()` on the quantizer, and chaining - a callback with `then()` to perform `update_quantized`. This invites - complications and also requires pre-allocating the quantized tensor. - Or other callbacks are possible if the type can be relaxed from torch.distributed.Work - to a duck typed done check. + Returns: quantizer(gather(inp)) + + NOTE: The implementation is not sophisticated enough to honor async_op=True. + In some cases it falls back to synchronous gather and invokes the quantizer. """ # Input tensor attributes diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ff5279c83c..df3ae05f31 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -168,7 +168,6 @@ def forward( columnwise_usage and with_input_all_gather and not isinstance(input_quantizer, MXFP8Quantizer) - and not isinstance(input_quantizer, Float8BlockQuantizer) ): columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) @@ -177,7 +176,7 @@ def forward( # or if a gather of ln_out must be in high precision. force_hp_blockwise_ln_out_gather = ( fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) - ) + ) # Perform TP communication in high precision. 
with_quantized_norm = ( fp8 and not return_layernorm_output diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d9a4058d4d..e51fe43cc0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -225,10 +225,10 @@ def forward( ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad - # TODO(kwyss): Support FP8 allgather of Float8Block quantization. + # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe force_hp_fc1_input_gather = ( fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) - ) + ) # Perform TP communication in high precision. # Configure quantizer for norm output if fp8: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7f223b3973..2887b2e452 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -134,7 +134,7 @@ def forward( # TODO(kwyss): Support FP8 allgather for FP8 block quantization. force_hp_input_gather = ( fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) - ) + ) # Perform TP communication in high precision. 
if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -156,8 +156,10 @@ def forward( else: if not isinstance(inputmat, QuantizedTensor): columnwise_usage = backward_needs_input and isinstance( - input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer) + input_quantizer, MXFP8Quantizer ) + # force_hp_input_gather should enforce this + assert not isinstance(input_quantizer, Float8BlockQuantizer) input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) inputmat = input_quantizer(inputmat) own_quantized_input = True From 324792b15caa677b4bb4a7306dc887b361c93d44 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 18:49:44 -0700 Subject: [PATCH 45/53] Enable cuda graph tests. Signed-off-by: Keith Wyss --- tests/pytorch/test_cuda_graphs.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 517cb11761..7bfe506f26 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -27,6 +27,9 @@ # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -55,9 +58,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), - # TODO: Support Float8BlockScaling with CUDA graph. - # One known issue is make_quantizers/num_quantizers, but - # sequential also should have changes. 
+ recipe.Float8BlockScaling(), ] # Supported data types @@ -319,9 +320,13 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and module == "linear_op": + pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") # Run model with different CUDA graph settings. model_config = model_configs[model_config] kwargs = dict( From 54e72792234fb48b580e965049d1c0b207b5ddc2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 8 Apr 2025 21:59:10 -0700 Subject: [PATCH 46/53] Reduce chance of spurious failure and reword. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8blockwisetensor.py | 4 ++++ transformer_engine/common/recipe/__init__.py | 12 +++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 2558a1d190..c42e1891e8 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -111,6 +111,10 @@ def _test_quantize_dequantize( # Initialize random data x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + if x_ref.numel() == 1: + # Reduce the chance that with trigger the trivial pass + # logic when all elements of x are close to zero. 
+ x_ref.view(1)[0] = 42.1415 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 8a9bf69c3a..80857e565c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -301,7 +301,8 @@ class Float8BlockScaling(Recipe): In this strategy, tensors are scaled in blockwise fashion. Values within each block share a common scaling factor. The block dimensionality - can be configured. The scaling factors are float32. + can be configured. The scaling factors are float32 containers. They + will by default be constrained to powers of 2. Since the scaling happens in a particular direction (either rowwise or columnwise), the quantized tensor and its transpose are not numerically @@ -310,6 +311,11 @@ class Float8BlockScaling(Recipe): during the quantization both versions are computed from the high precision input to avoid double quantization errors. + NOTE: To relax the default constraint that scales be powers of 2, set env variable + NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. + export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code for increased control. + Parameters ---------- fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 @@ -333,10 +339,6 @@ class Float8BlockScaling(Recipe): use for calculating dgrad in backward pass fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True use for calculating dgrad in backward pass - Notes: By default, fp8_quant_fwd_inp, fp8_quant_fwd_weight, fp8_quant_bwd_grad are set to power of 2 scales. - To Enable FP32 scales, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it. - export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 - Or initialize the Recipe with non-default QParams in code. 
""" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" From 962d9c53423e604f3c22c3ad634bc5a0d66e4f7c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 9 Apr 2025 16:25:23 -0400 Subject: [PATCH 47/53] [JAX] Scaling Enum Abstracting (#1655) * scaling enum abstract * rm NVTE_ from ScalingMode names * rework scaling mode enum in grouped gemm * fix norm sharding --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 4 +- examples/jax/encoder/test_multigpu_encoder.py | 4 +- .../jax/encoder/test_single_gpu_encoder.py | 4 +- examples/jax/mnist/test_single_gpu_mnist.py | 4 +- qa/L0_jax_distributed_unittest/test.sh | 2 +- tests/jax/test_custom_call_compute.py | 40 ++--- tests/jax/test_distributed_layernorm.py | 2 +- tests/jax/test_distributed_layernorm_mlp.py | 2 +- tests/jax/test_layer.py | 4 +- .../jax/cpp_extensions/activation.py | 40 ++--- transformer_engine/jax/cpp_extensions/gemm.py | 34 ++-- transformer_engine/jax/cpp_extensions/misc.py | 2 +- .../jax/cpp_extensions/normalization.py | 162 ++++++++---------- .../jax/cpp_extensions/quantization.py | 24 +-- transformer_engine/jax/csrc/extensions.h | 22 ++- .../jax/csrc/extensions/activation.cpp | 65 +++---- .../jax/csrc/extensions/gemm.cpp | 37 ++-- transformer_engine/jax/csrc/extensions/misc.h | 23 +++ .../jax/csrc/extensions/normalization.cpp | 21 ++- .../jax/csrc/extensions/pybind.cpp | 8 +- .../jax/csrc/extensions/quantization.cpp | 58 +++++-- transformer_engine/jax/flax/module.py | 2 +- .../jax/quantize/dequantizer.py | 4 +- transformer_engine/jax/quantize/helper.py | 16 +- transformer_engine/jax/quantize/quantizer.py | 15 +- .../jax/quantize/scaling_modes.py | 25 ++- transformer_engine/jax/quantize/tensor.py | 3 +- 27 files changed, 335 insertions(+), 292 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7e6605c9fe..eabd1b2a3f 100644 --- 
a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -448,8 +448,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ba62d964fa..839bc3175e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -416,8 +416,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1300be01bb..df78157cc5 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -327,8 +327,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, 
mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 4022cb7493..435750a1db 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -306,8 +306,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3253861484..3fbfb9cf5c 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +. 
$TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4dc07a2eea..8917e92465 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -48,21 +48,21 @@ LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) + supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) def is_shape_supported_by_mxfp8(input_shape): try: if isinstance(input_shape, type(pytest.param(0))): input_shape = input_shape.values[0] - ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) + ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True except: # get_scale_shapes will raise an exception if the shape is not supported @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + 
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=q_layout, ) @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8( self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout, ) @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_layout=q_layout, ) @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -533,7 +533,7 @@ class TestFusedQuantize: def test_quantize_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis ): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( + if scaling_mode == 
ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization( in_dtype=in_dtype, input_shape=input_shape, out_dtype=in_dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=QuantizeLayout.ROWWISE, @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout): scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in 
TestNorm @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 6d4cde364f..476d455a6a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -29,7 +29,7 @@ } is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 4350d5e8f3..cf311ac404 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -36,7 +36,7 @@ is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = 
is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b89530c19f..a21583a98c 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -39,7 +39,7 @@ def enable_fused_attn(): is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) QUANTIZE_RECIPES = [] """ Find supported scaling modes""" @@ -313,7 +313,7 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: _, updated_quantize_meta = flax.core.pop( updated_state[0], QuantizeConfig.COLLECTION_NAME ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7676781c3..c27f6f50f7 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -162,7 +162,7 @@ def lowering( assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ) return out @@ -282,7 +282,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -293,9 +293,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = 
colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -339,7 +339,7 @@ def partition( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -350,9 +350,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -391,7 +391,7 @@ def sharded_impl(x, scale): ) ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -463,7 +463,7 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -545,7 +545,7 @@ def lowering( dz, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), @@ 
-673,7 +673,7 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -691,9 +691,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -743,7 +743,7 @@ def partition( ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -761,9 +761,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -810,7 +810,7 @@ def sharded_impl(dz, x, scale): else: global_dbias = local_dbias - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -928,7 +928,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + 
scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1042,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused @@ -1095,7 +1095,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1df2bcc97f..0327542c2f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -98,7 +98,7 @@ def lowering( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + scaling_mode=scaling_mode.value, ) @staticmethod @@ -123,7 +123,7 @@ def impl( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, out_flat_size=out_flat_size, ) @@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8( ): """FP8 GEMM for XLA pattern match""" assert ( - rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d( JAX GEMM for MXFP8 via scaled_matmul """ assert ( - rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING + 
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper @@ -291,10 +291,10 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) - if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") @@ -403,7 +403,7 @@ def grouped_gemm( rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -415,7 +415,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -427,13 +427,13 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = 
_shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) @@ -470,13 +470,13 @@ def grouped_gemm( dims.append((bm, bn, k)) lhs_contig_.append(lhs_3d.reshape(-1)) rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) if bias_list is not None: @@ -493,8 +493,8 @@ def grouped_gemm( # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + if scaling_mode == ScalingMode.NO_SCALING: + scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING # Perform batched GEMM on flattened inputs out_contig = GroupedGemmPrimitive.outer_primitive.bind( @@ -505,7 +505,7 @@ def grouped_gemm( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, out_flat_size=out_flat_size, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index c79eda5568..d64104ac27 100644 --- 
a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 74882c92db..388d4f17ee 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -105,6 +105,26 @@ def abstract( if norm_type == NVTE_Norm_Type.LayerNorm: assert gamma_aval.size == beta_aval.size + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + if norm_type == NVTE_Norm_Type.RMSNorm: + mu_aval = mu_aval.update(shape=(1,)) + + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -112,33 +132,13 @@ def abstract( jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(out_dtype), norm_type, - 
scaling_mode.value, + scaling_mode, zero_centered_gamma, epsilon, get_forward_sm_margin(), is_2x, ) - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_aval = mu_aval.update(shape=(1,)) - - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x_aval.shape, is_padded=not is_outer - ) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) - colwise_out_aval = jax.core.ShapedArray( - shape=x_aval.shape if is_2x else (1,), dtype=out_dtype - ) - - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - - wkspace_aval = x_aval.update( + wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) @@ -274,9 +274,9 @@ def impl( scale_shapes=scale_shapes, is_outer=False, ) - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x.shape, is_padded=False - ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) # slice out padding for mxfp8, noop for DelayedScaling scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape @@ -364,6 +364,8 @@ def infer_sharding_from_operands( del zero_centered_gamma, epsilon, out_dtype, result_infos del scale_dtype, scale_shapes, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) + out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -371,34 +373,27 @@ def infer_sharding_from_operands( "and hurt performance." 
) - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") + + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") output = ( out_sharding, colwise_out_sharding, @@ -427,8 +422,11 @@ def partition( ): del result_infos, is_outer x_spec = 
get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) g_spec = get_padded_spec(arg_infos[2]) b_spec = get_padded_spec(arg_infos[3]) + out_spec = (*x_spec[:-1], None) + if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -445,43 +443,30 @@ def partition( f"{NormFwdPrimitive.name} does not support sharding of parameter beta " "Enforcing no sharding of parameters hidden dim! " ) - x_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x" - ) - g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma") - b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta") - out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out") - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" + ) rsigma_sharding = NamedSharding( - mesh, - PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]), - desc="NormFwdPrimitive.rsigma", + mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), 
desc="NormFwdPrimitive.scale" - ) - scale_inv_sharding = scale_sharding.duplicate_with_new_description( - "NormFwdPrimitive.scale_inv" + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -517,7 +502,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -824,7 +809,6 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) - if quantizer is None: output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, @@ -835,7 +819,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -845,7 +829,7 @@ def layernorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed 
scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -864,7 +848,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -873,7 +857,7 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -882,7 +866,7 @@ def layernorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. 
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) @@ -1017,7 +1001,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1027,7 +1011,7 @@ def rmsnorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -1046,7 +1030,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -1055,7 +1039,7 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1064,7 +1048,7 @@ def rmsnorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. 
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 034e149c50..2911b5a420 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -93,7 +93,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -114,6 +114,10 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + scaling_mode, + QuantizeLayout( + q_layout + ), # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -176,7 +180,7 @@ def lowering( ctx, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, @@ -302,7 +306,7 @@ def infer_sharding_from_operands( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -322,9 +326,9 @@ def 
infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -374,7 +378,7 @@ def partition( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -394,9 +398,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -445,7 +449,7 @@ def sharded_impl(x, scale): is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -588,7 +592,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == 
ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -31,6 +31,9 @@ #include "transformer_engine/activation.h" #include "utils.h" +// ENUM_ATTR and DICT_ATTR decoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax { @@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); + +pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, bool is_2x); + // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); @@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training); @@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); - -pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x); + DType in_dtype, DType out_dtype, + 
JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e71597e4b3..fc7f231f34 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -17,7 +17,7 @@ namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, + Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis @@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale 
!= nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("act_enum") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto dact_input_tensor = TensorWrapper(reinterpret_cast(&temp), 
dact_input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); - auto output_tensor = TensorWrapper(static_cast(scaling_mode)); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter @@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } } - if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, - bool is_dbias, int64_t act_enum) { + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = 
convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); - auto scaling_mode = static_cast(scaling_mode_enum); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? 
scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK( - !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { @@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") + .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias") - .Attr("act_enum"), + .Attr("is_dbias"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e5ec160c91..d4b9bf720e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, uint8_t *workspace_ptr, const size_t 
workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const int64_t &scaling_mode, + int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, cudaStream_t stream) { size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); @@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, - nullptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, - nullptr, reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); + rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); + + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); size_t sinv_k = k / MXFP8_BLOCK_SIZE; @@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh rhs_sinv_shape[1] = sinv_k; // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); - 
lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, lhs_sinv_shape); rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); lhs_ptr += m * k * lhs_dtype_bytes; @@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { + Result_Type workspace_flatten, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode) { // Inputs auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); @@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Ret() // out_flatten .Ret() // workspace_flatten .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8526e20c0..f7577c24f3 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -40,5 +40,28 @@ enum class QuantizeLayout { ROWWISE_COLWISE, }; +enum class JAXX_Scaling_Mode : int64_t { + NO_SCALING = 0, + DELAYED_TENSOR_SCALING = 1, + MXFP8_1D_SCALING = 2, +}; + +static 
NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { + switch (mode) { + case JAXX_Scaling_Mode::NO_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::MXFP8_1D_SCALING: + return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + break; + default: + NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 03855753cf..e23e42f528 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -14,7 +14,8 @@ namespace jax { pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - auto _scaling_mode = static_cast(scaling_mode); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated - if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { int temp = 1; 
output_tensor.set_columnwise_data(static_cast(&temp), out_dtype, input_shape); } @@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, @@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, int scaling_mode, bool is_2x) { + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); - auto _scaling_mode = static_cast(scaling_mode); auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto output_tensor = 
TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); if (is_fp8_dtype(out_dtype)) { @@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc scale_inv_buf->dimensions().back()}); } - if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); @@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, @@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("zero_centered_gamma") .Attr("epsilon") .Attr("sm_margin") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ebdfe461c7..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("RMSNorm", NVTE_Norm_Type::RMSNorm) .export_values(); - 
pybind11::enum_(m, "NVTE_Scaling_Mode", pybind11::module_local()) - .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) - .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) - .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) + .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index b48ee8a9b9..481dbd7cdf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -13,7 +13,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ int temp = 0; auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); - auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if 
(q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); + + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + TensorWrapper dummy_workspace; nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), @@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type output_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, - int64_t flatten_axis) { + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -54,7 +83,6 @@ Error_Type 
DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto scaling_mode = static_cast(scaling_mode_enum); auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); @@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T std::vector workspace_shape{workspace_dims.begin(), workspace_dims.end()}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? 
scale_inv_buf + : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis"), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a944848881..45ff8d7ed9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -361,7 +361,7 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b1e9ba03b4..d68eb3c6c2 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -84,8 +84,8 @@ def _dq_func_block_scaling(scaled_tensor): ) funcs = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, + ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } @staticmethod diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d144aa69d..98f280b9a9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ 
-94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) return (False, "Unsupported scaling_mode!") def is_fp8_available( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, ) -> Tuple[bool, str]: """Check if FP8 is available for the given scaling mode and GPU. @@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ValueError: If the recipe type is not supported """ if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.NVTE_DELAYED_TENSOR_SCALING + return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.NVTE_MXFP8_1D_SCALING + return ScalingMode.MXFP8_1D_SCALING raise ValueError("Invalid fp8_recipe!") @@ -217,7 +217,7 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False IF_QUANTIZE_2X: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING + SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 @@ -253,11 +253,11 @@ def finalize(cls) -> None: cls.MARGIN = 0.0 cls.FP8_FORMAT = recipe.Format.HYBRID cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING 
cls.IF_QUANTIZE_2X = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index bd7045453b..b57043a034 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -530,8 +530,8 @@ class QuantizerFactory: """ quantizer_type_map = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } @staticmethod @@ -556,8 +556,9 @@ def create( A single quantizer or tuple of quantizers """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted - # assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING - if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): + assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" + # import pdb; pdb.set_trace() + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] @@ -651,4 +652,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = 
QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 95bbc9bb41..34f63a994c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -19,6 +19,8 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp +from transformer_engine_jax import JAXX_Scaling_Mode + __all__ = ["ScalingMode"] @@ -216,25 +218,20 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) -# (Phuong: Map the NVTEScalingMode value to the ScalingMode - - @dataclass(frozen=True) @register_pytree_node_class class ScalingMode(Enum): """Enumeration of tensor scaling modes with their corresponding metadata implementations. This class defines the available scaling modes for tensor quantization: - - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_INVALID_SCALING: Invalid scaling mode - - NVTE_NO_SCALING: No scaling applied + - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales + - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - NO_SCALING: No scaling applied """ - NVTE_DELAYED_TENSOR_SCALING = 0 - NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 100 - NVTE_NO_SCALING = 1000 + NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. 
@@ -329,8 +326,8 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR - ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c34a235d94..0ef30f4728 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -236,13 +236,12 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # TODO(Phuong): Handle padding !? scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv - # TODO(Phuong): constaint padded scale_inv? 
return ScaledTensor1x( data=data, scale_inv=scale_inv, From 20e95ba3d3f7af540678af74dc5960332221a0b3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:05:13 -0700 Subject: [PATCH 48/53] [PyTorch] Explicitly specify quantized tensor usages needed for linear op backward (#1646) Explicitly specify quantized tensor usages needed for linear op backward Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 9 ++++-- .../pytorch/ops/basic/basic_linear.py | 32 +++++++++++-------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9c1a842cd8..ddc79af426 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1420,15 +1420,17 @@ def test_activation( test_device=device, test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_is_fp8=quantized_compute, requires_grad=False, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1459,6 +1461,7 @@ def test_activation( swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantized_compute), make_op(), te_ops.Quantize(forward=quantized_compute, backward=False), ) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb93eb5e6b..b451acea9a 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -523,7 +523,7 @@ def _functional_forward( # Configure input tensor for backward pass if own_quantized_x_local: - 
x_local.update_usage(rowwise_usage=False) + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save @@ -679,7 +679,9 @@ def _functional_backward( quantizer=input_quantizer, ) else: - if not isinstance(x_local, QuantizedTensor): + if isinstance(x_local, QuantizedTensor): + x_local.update_usage(columnwise_usage=True) + else: x_local = input_quantizer(x_local) x = x_local else: @@ -706,15 +708,19 @@ def _functional_backward( raise ValueError("Weight tensor is required to compute input grad") w = weight w_is_quantized = isinstance(w, QuantizedTensor) - if with_quantized_compute and not w_is_quantized: - if weight_quantizer is None: - raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(columnwise=True) - w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) + if with_quantized_compute: + if w_is_quantized: + w.update_usage(columnwise_usage=True) + else: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + else: + if w_is_quantized: + w = w.dequantize(dtype=dtype) + elif w.dtype != dtype: + w = w.to(dtype=dtype) # Synchronize tensor-parallel communication _wait_async(dy_async) @@ -867,8 +873,8 @@ def op_forward( # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
- input_quantizer.set_usage(columnwise=weight_requires_grad) - weight_quantizer.set_usage(columnwise=False) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) # Get autocast dtype if needed dtype = None From 0da604497446c5ef4cf28e35c7362f5ad913e434 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:05:24 -0700 Subject: [PATCH 49/53] [PyTorch] Debug checkpointing with te.Sequential (#1629) * Debug checkpointing with te.Sequential Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 116 ++++++++++++++++++ transformer_engine/pytorch/ops/op.py | 71 ++++++----- .../pytorch/tensor/mxfp8_tensor.py | 5 +- 3 files changed, 158 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ddc79af426..59af228861 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Iterable +import io import math from typing import Optional @@ -1885,3 +1886,118 @@ def test_backward_linear_add( torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +class TestCheckpointing: + """Tests for checkpointing""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear( + self, + *, + pre_checkpoint_steps: int = 2, + 
post_checkpoint_steps: int = 2, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Check checkpointing with linear op""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + + # Construct model + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_save = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25) + + # Warmup training steps + for _ in range(pre_checkpoint_steps): + x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn(out_shape, dtype=dtype, device=device) + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(x) + y.backward(dy) + optim_save.step() + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save( + {"model": model_save.state_dict(), "optim": optim_save.state_dict()}, + byte_stream, + ) + checkpoint_bytes = byte_stream.getvalue() + del byte_stream + + # Synthetic data for evaluation + xs_save = [ + torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + for _ in range(post_checkpoint_steps) + ] + with torch.no_grad(): + xs_load = [x.clone().requires_grad_() for x in xs_save] + dys = [ + torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps) + ] + + # Training steps with original model + ys_save = [] + for i in 
range(post_checkpoint_steps): + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(xs_save[i]) + y.backward(dys[i]) + optim_save.step() + ys_save.append(y) + + # Load checkpoint + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_load = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) + state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) + model_load.load_state_dict(state_dict["model"]) + optim_load.load_state_dict(state_dict["optim"]) + + # Training steps with loaded model + ys_load = [] + for i in range(post_checkpoint_steps): + optim_load.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_load(xs_load[i]) + y.backward(dys[i]) + optim_load.step() + ys_load.append(y) + + # Check that original and loaded model match exactly + tols = {"rtol": 0, "atol": 0} + for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): + torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close(param_load.grad, param_save.grad, **tols) + for y_load, y_save in zip(ys_load, ys_save): + torch.testing.assert_close(y_load, y_save, **tols) + for x_load, x_save in zip(xs_load, xs_save): + torch.testing.assert_close(x_load.grad, x_save.grad, **tols) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..ad32055479 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -19,6 +19,7 @@ DelayedScalingRecipeState, FP8GlobalStateManager, RecipeState, + fp8_autocast, ) from ..tensor import Quantizer @@ -508,7 +509,7 @@ def forward( def get_extra_state(self) -> torch.Tensor: """Serialize extra state - Contains metadata for FP8 casting. + Contains metadata for quantization recipe. 
""" @@ -540,23 +541,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: dst.copy_(src, non_blocking=True) return dst - # Store FP8 state + # Store quantizer state if needed state = {} for mode in ("forward", "backward"): - # Get state for a given FP8 tensor - if self.num_quantizers(mode) == 0: + # Skip if op has no quantizer state + if self._fp8_metas is None or self._fp8_metas[mode] is None: continue - fp8_meta = self.get_fp8_meta(mode) + + # Quantizer state + fp8_meta = self._fp8_metas[mode] state[mode] = {} + state[mode]["recipe"] = fp8_meta["recipe"] - # Store tensors - if "scaling_fwd" in fp8_meta: - state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) - if "scaling_bwd" in fp8_meta: - state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + # Copy tensors to CPU and store + if state[mode]["recipe"].delayed(): + if mode == "forward": + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items extra = {} @@ -595,37 +600,37 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) dst.copy_(src, non_blocking=True) - # Load FP8 state + # Load quantizer state if needed for mode in ("forward", "backward"): - # Get state for a given FP8 tensor + # Skip if checkpoint has no quantizer state if mode not in state: continue - if self.num_quantizers(mode) == 0: - continue - fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue - # Load extra state + # Get op's quantizer state, initializing if needed + if 
self._fp8_metas is None or self._fp8_metas[mode] is None: + with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + self._reset_quantization_recipe_state() + fp8_meta = self._fp8_metas[mode] + + # Load extra items + fp8_meta["recipe"] = state[mode]["recipe"] fp8_meta.update(state[mode]["extra_fp8_variables"]) - if "amax_history_fwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) - elif "amax_history_bwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Load tensors - fp8_meta = self.get_fp8_meta(mode) - if "scaling_fwd" in fp8_meta: - fp8_meta_fwd = fp8_meta["scaling_fwd"] - copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) - if "scaling_bwd" in fp8_meta: - fp8_meta_bwd = fp8_meta["scaling_bwd"] - copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + if state[mode]["recipe"].delayed(): + if mode == "forward": + copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) + copy_tensor( + state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history + ) + if mode == "backward": + copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) + copy_tensor( + state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history + ) # Finish CPU-GPU memory transfers torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 843c7936f2..2694319a0f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -347,6 +347,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + shape: torch.shape, 
) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -361,10 +362,11 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling""" return ( MXFP8Tensor._make_in_reduce_ex, ( @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self.shape, ), ) From 0bf7844b3e45b763b073431070cac8185e7dbe48 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 10 Apr 2025 00:48:37 +0000 Subject: [PATCH 50/53] Review suggestions from @timmoon10 Signed-off-by: Tim Moon --- tests/pytorch/test_float8_current_scaling_exact.py | 2 +- tests/pytorch/test_float8blockwisetensor.py | 7 +++---- transformer_engine/pytorch/csrc/common.h | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index b8e5b936d1..8911ecc159 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,7 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - error = torch.where(b == 0, 0.0, torch.abs((a - b) / b)) + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) return torch.mean(error) @staticmethod diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index c42e1891e8..ccb67f2994 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -110,11 +110,10 @@ def _test_quantize_dequantize( dims = _to_list(dims) # Initialize random data + # Note: Make sure values are not all close to zero, or else + # test may pass trivially. 
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - if x_ref.numel() == 1: - # Reduce the chance that with trigger the trivial pass - # logic when all elements of x are close to zero. - x_ref.view(1)[0] = 42.1415 + x_ref.view(1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2d02a7d920..338f1fcbb1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -168,13 +168,12 @@ class Float8BlockQuantizer : public Quantizer { // Which float8 type is used for q data. DType dtype; + private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; - - private: int block_scaling_dim = 2; public: From 7efac72ec7692979a86e63b2b1d60c02f752ad8a Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 9 Apr 2025 17:26:56 -0700 Subject: [PATCH 51/53] Update CPP tests. 
Signed-off-by: Keith Wyss --- .../cpp/operator/test_cast_float8blockwise.cu | 32 +++++++++++++------ tests/cpp/test_common.cu | 7 +--- tests/cpp/test_common.h | 14 ++------ 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cc27f72769..10b52e065f 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -19,6 +19,12 @@ using namespace test; namespace { +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + constexpr size_t kBlockLen = 128; enum ProcessingMethod { @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector(&input, fill_case); fillUniform(&grad); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, - opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + opts.block_scaling_dim == 2 ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, fillUniform(&grad); Tensor workspace; + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, - {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {1, 16}, {65, 96}, {256, 256}, {993, 512}, + {256, 65536}, {4096, 1632}, {1024, 1}, + {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, }; std::vector input_scenarios = { @@ -429,6 +441,8 @@ std::vector Activation_types = { std::vector amax_epsilons = { 0.0f, + 1.0f, // Make large to be observable. 
+ }; } // namespace @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 071c2186e0..61d3075265 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -216,8 +216,7 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode, - const QuantizationOptions* q_opts) { + const NVTEScalingMode &scaling_mode) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } - if (q_opts != nullptr) { - NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); - NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current 
implementation."); - } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 08df3cf7d1..d5ecc6d0f5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,29 +95,21 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; -struct QuantizationOptions { - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - size_t block_scaling_dim = 2u; -}; - class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} From c3ee3d82056c77d60542aef251c34f082744a872 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 10 Apr 2025 11:50:23 +0800 Subject: [PATCH 52/53] Update common.h Signed-off-by: Xin Yao --- transformer_engine/pytorch/csrc/common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..3b349e7f09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. 
bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + + private: int block_scaling_dim = 2; public: From 59cb49c4cfdfc56fba52fcc617ccc8b15f1a9a0b Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 10 Apr 2025 13:35:06 +0800 Subject: [PATCH 53/53] Update test_float8blockwisetensor.py Signed-off-by: Xin Yao --- tests/pytorch/test_float8blockwisetensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index ccb67f2994..6d3e879970 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -113,7 +113,7 @@ def _test_quantize_dequantize( # Note: Make sure values are not all close to zero, or else # test may pass trivially. x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_ref.view(1)[0] = 0.75 + x_ref.view(-1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back