diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cc27f72769..10b52e065f 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -19,6 +19,12 @@ using namespace test; namespace { +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + constexpr size_t kBlockLen = 128; enum ProcessingMethod { @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector(&input, fill_case); fillUniform(&grad); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, - opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + opts.block_scaling_dim == 2 ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, fillUniform(&grad); Tensor workspace; + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, - {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {1, 16}, {65, 96}, {256, 256}, {993, 512}, + {256, 65536}, {4096, 1632}, {1024, 1}, + {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, }; std::vector input_scenarios = { @@ -429,6 +441,8 @@ std::vector Activation_types = { std::vector amax_epsilons = { 0.0f, + 1.0f, // Make large to be observable. 
+ }; } // namespace @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 071c2186e0..61d3075265 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -216,8 +216,7 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode, - const QuantizationOptions* q_opts) { + const NVTEScalingMode &scaling_mode) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } - if (q_opts != nullptr) { - NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); - NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current 
implementation."); - } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 08df3cf7d1..d5ecc6d0f5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,29 +95,21 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; -struct QuantizationOptions { - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - size_t block_scaling_dim = 2u; -}; - class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ae5993eb1e..b423bce53d 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -19,6 +19,7 @@ MXFP8BlockScaling, DelayedScaling, Float8CurrentScaling, + Float8BlockScaling, Format, Recipe, ) @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe: return MXFP8BlockScaling() if QUANTIZATION == "fp8_cs": return Float8CurrentScaling() + if QUANTIZATION == "fp8_block_scaling": + return Float8BlockScaling() return te.fp8.get_default_fp8_recipe() @@ -85,7 +88,7 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION 
in ("fp8", "mxfp8", "fp8_block_scaling"): global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE SEQ_LEN = 32 BATCH_SIZE = 32 diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index b4e2b680b3..632f50e90a 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -28,6 +28,9 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -48,7 +51,7 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -56,4 +59,6 @@ def test_distributed(quantization): pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) _run_test(quantization) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5a1dc3f732..7bfe506f26 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -27,6 +27,9 @@ # Check if FP8 is supported. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -55,6 +58,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] # Supported data types @@ -316,9 +320,13 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and module == "linear_op": + pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") # Run model with different CUDA graph settings. 
model_config = model_configs[model_config] kwargs = dict( diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9a1cfa2db8..ec23cfe8c5 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,21 +8,18 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ) + supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + return supported def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e638fe8c5b..0baee4975d 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -4,11 +4,14 @@ from typing import Tuple import math +import os +import pathlib import pytest import torch import transformer_engine as te import transformer_engine_torch as tex -from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -18,10 +21,29 @@ 
BlockwiseQuantizerReference, QuantizeResult, ) +from test_float8_current_scaling_exact import ( + TestFP8RecipeLinearBase, + TestFP8RecipeLayerNormLinearBase, +) + +# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + + +class GetRecipes: -# TODO replace with call to fp8.py when recipe added. -recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 -reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + @staticmethod + def none(): + return None + + @staticmethod + def fp8_blockwise(): + # return default configs + return Float8BlockScaling() def initialize_for_many_scales( @@ -66,35 +88,7 @@ def initialize_for_many_scales( return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles. 
- (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] -) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) -def test_quantization_block_tiling_versus_reference( +def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, @@ -199,12 +193,90 @@ def test_quantization_block_tiling_versus_reference( [ # full tile cases (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # 
full tile cases + (256, 256), + (2048, 1024), + # Padding required cases + (256, 272), + (303, 300), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference_fp32_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( @@ -292,3 +364,130 @@ def test_quantization_block_tiling_extrema_versus_reference( atol=0.0, rtol=0.0, ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + 
seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + 
fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1.6, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 9741b1258c..8911ecc159 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,8 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - return torch.mean(torch.abs((a - b) / b)) + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) @staticmethod def _load_golden_tensor_values(a, b): @@ -97,9 +98,12 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { @@ -437,9 
+441,13 @@ def _check_golden_tensor_dumps( fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" + current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index d030426b74..6d3e879970 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -110,7 +110,10 @@ def _test_quantize_dequantize( dims = _to_list(dims) # Initialize random data + # Note: Make sure values are not all close to zero, or else + # test may pass trivially. x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref.view(-1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back @@ -150,6 +153,24 @@ def test_quantize_dequantize_dtypes( ) self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1]) + def test_quantize_dequantize_columnwise_only( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=False, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True + ) + @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 
5, 3]] ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 42600e3099..d36da704b0 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import io -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytest import torch @@ -158,6 +158,32 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("noop", [True, False]) + def test_quantize_dequantize_noop( + self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool + ) -> None: + noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") + if noop: + noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") + dims = 23 + scale: float = 3.5 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor) + if noop_tensor.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) + def test_basic_ops( self, dims: DimsType = 23, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..7a930b6cde 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,6 +50,9 @@ # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) sm_80plus = get_device_compute_capability() >= (8, 0) @@ -104,6 +107,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] @@ -563,6 +567,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -675,6 +681,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1528,6 +1536,8 @@ def test_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Grouped linear for FP8 blockwise unsupported.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1723,6 +1733,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Float8 block scaling unsupported for grouped linear.") config = model_configs[model] if 
config.seq_len % 16 != 0 and fp8: @@ -1933,6 +1945,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 69ac8f7996..d8552d63a4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,6 +46,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -106,6 +109,7 @@ def is_fp8_supported(self): None, # Test non-FP8 recipe.MXFP8BlockScaling(), # Test default recipe.Float8CurrentScaling(), # Test default + recipe.Float8BlockScaling(), # Test default recipe.DelayedScaling(), # Test default recipe.DelayedScaling( # Test most_recent algo amax_history_len=16, @@ -439,6 +443,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -470,6 +476,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -502,6 
+510,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -543,10 +553,14 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8(): pytest.skip("Grouped linear does not support MXFP8") if fp8_recipe.float8_current_scaling(): pytest.skip("Grouped linear does not support FP8 current scaling") + if fp8_recipe.float8_block_scaling(): + pytest.skip("Grouped linear does not support FP8 block scaling") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -590,6 +604,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -640,6 +656,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -707,6 +725,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not 
fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -766,6 +786,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -823,6 +845,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -858,6 +882,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -896,6 +922,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -937,6 +965,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if 
fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -979,8 +1009,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling(): + pytest.skip("cuda graph not supported for float8_block_scaling recipe") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 708403f911..67f173a4ab 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..728f8ad147 100644 --- a/transformer_engine/common/common.h +++ 
b/transformer_engine/common/common.h @@ -233,10 +233,12 @@ struct Tensor { struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; + NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float) // amax_epsilon + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7fa7957fa4..64136b2c43 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,7 +89,7 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. @@ -102,6 +102,16 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); +/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output quantized tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ba47b9d38c..d3ee446f83 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -286,6 +286,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, + /*! Noop tensor (containing a scalar). + If the scalar element value = 1, quantization kernel will early exit. + This is a tensor because the flag must be on GPU in order to enable + conditional early exit even when captured in a static CUDA graph. + */ + kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; @@ -724,6 +730,12 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } + /*! \brief Set noop tensor pointer */ + void set_noop_tensor(NVTETensor noop_tensor) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor, + sizeof(NVTETensor)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b676bf6ab0..80857e565c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -5,6 +5,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations import warnings +import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -81,6 +82,10 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + def float8_block_scaling(self): + """Whether the given recipe is float8 blockwise scaling.""" + return isinstance(self, Float8BlockScaling) + @dataclass() class DelayedScaling(Recipe): @@ -287,3 +292,99 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + + +@dataclass() +class Float8BlockScaling(Recipe): + """ + Use block-wise scaling for FP8 tensors. + + In this strategy, tensors are scaled in blockwise fashion. Values within + each block share a common scaling factor. The block dimensionality + can be configured. The scaling factors are float32 containers. They + will by default be constrained to powers of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), the quantized tensor and its transpose are not numerically + equivalent. Due to this, when Transformer Engine needs both the FP8 tensor + and its transpose (e.g. to calculate both forward and backward pass), + during the quantization both versions are computed from the high precision + input to avoid double quantization errors. 
+ + NOTE: To relax the default constraint that scales be powers of 2, set env variable + NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. + export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code for increased control. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of gradient tensor dY + x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for x. + w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for w. + grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for grad. 
+ fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=True + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating wgrad in backward pass + """ + + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" + + fp8_format: Format = Format.E4M3 + fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + x_block_scaling_dim: int = 1 + w_block_scaling_dim: int = 2 + grad_block_scaling_dim: int = 1 + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" + assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" + assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" + assert not ( + self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." + assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." 
+ assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"x_block_scaling_dim={self.x_block_scaling_dim}, " + f"w_block_scaling_dim={self.w_block_scaling_dim}, " + f"grad_block_scaling_dim={self.grad_block_scaling_dim}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 97df5892b6..706e0bd0b5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -429,6 +429,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(buf, &config_.amax_epsilon, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(buf, &config_.noop_tensor, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -458,6 +461,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(&config_.amax_epsilon, buf, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(&config_.noop_tensor, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 298d087337..3148b4f720 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ 
b/transformer_engine/common/transpose/cast_transpose.h @@ -29,11 +29,35 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream); +// enum class for rowwise usage +enum class FP8BlockwiseRowwiseOption { + // No rowwise data + NONE, + // Rowwise data, scales in GEMM format + ROWWISE + // TODO: FP8 all gather requires some changes. + // 1. Compact scales are better for gathering than the GEMM format. +}; + +// enum class for columnwise usage +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class FP8BlockwiseColumnwiseOption { + // No columnwise data + NONE, + // Columnwise data transposed from original shape. + // Scales in GEMM format corresponding to GEMM ingesting transposed column data. + COLUMNWISE_TRANSPOSE + // TODO: FP8 all gather requires some changes. + // 1. The transpose gets in the way of the all gather. + // 2. Compact scales are better for gathering than the GEMM format. 
+}; + void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 732d97999c..91f73dea1e 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -16,11 +16,15 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" #include "common/utils.cuh" namespace transformer_engine { namespace { +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + // clang-format off /* @@ -138,15 +142,17 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); template -__global__ void __launch_bounds__(kThreadsPerBlock) - block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, - OType* const output_t, CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, const size_t row_length, - const size_t num_rows, const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon, - bool return_transpose, bool pow_2_scaling) { +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, OType* const 
output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scaling) { + bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; + bool return_columnwise_transpose = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; + using SMemVec = Vec; using OVec = Vec; union IVec { @@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast and store to output_c - { + if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if (return_transpose) { + if (return_columnwise_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -389,10 +395,15 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, - cudaStream_t stream) { + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + // assert that rowwise_option and columnwise_option are not both NONE + 
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || + columnwise_option != FP8BlockwiseColumnwiseOption::NONE, + "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -408,21 +419,24 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. - - NVTE_CHECK(input.shape.size() == output.shape.size(), - "Input and output must have the same shape."); - size_t scale_stride_x = 0; size_t scale_stride_y = 0; - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; - size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (return_transpose) { + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + "Unexpected rowwise enum value"); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + } + + if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, + "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -469,10 +483,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, - pow2_scale);) // kAligned - ) // 
OutputType - ) // InputType + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 22a50025df..1f146c7a33 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no NVTE_API_CALL(nvte_quantize_noop); using namespace transformer_engine; + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = false; constexpr bool IS_ACT = false; @@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, noop, output, - dbias, workspace, stream); + detail::quantize_helper( + input, grad, output, dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ 
-66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +92,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +148,7 @@ 
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c6a8b0f23c..a599d88530 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,9 +1215,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe auto output_tensor = reinterpret_cast(output); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); + + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); switch (output_tensor->scaling_mode) { @@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. 
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, + output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - constexpr bool force_pow_2_scales = true; - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() + ? FP8BlockwiseRowwiseOption::ROWWISE + : FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + output_tensor->has_columnwise_data() ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE + : FP8BlockwiseColumnwiseOption::NONE; + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..79d6391e79 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = [ "general_gemm", @@ -112,6 +113,10 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + # There is not use_split_accumulator == False + # implementation for Float8BlockwiseQTensorBase GEMM + use_split_accumulator = True args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..3b349e7f09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. 
float amax_epsilon = 0.0; + + private: int block_scaling_dim = 2; public: diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 1ef6f5258d..bf037fe931 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + // sanity check, since activation fusion is not supported for blockwise quantization yet + // need to raise an error here instead of silently going into act_func with wrong numerics + NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); } else { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c3ccff154..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob if (te_output.numel() == 0) return out; + QuantizationConfigWrapper quant_config; + quant_config.set_noop_tensor(te_noop.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto 
my_quantizer_cs = static_cast(my_quantizer.get()); @@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; + // this config is used for cs scaling factor computation + // because compute scale is cannot be fused with quantize kernel + // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index cbdeee5b48..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,6 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to 
be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -166,15 +167,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -293,6 +297,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -309,15 +314,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); 
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ac6292e53..fbf31a7f5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -257,12 +257,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires pow2 scales"); - NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires amax_epsilon==0"); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block 
scaling dim."); } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index a873586032..e12990f79c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -69,6 +69,7 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + // TODO: switch to nvte_quantize_v2 with advanced numerical options nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..7a1fde164b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -24,10 +24,11 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -937,6 +938,74 @@ def _all_gather_fp8( return out, handle +def _all_gather_fp8_blockwise( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, # pylint: disable=unused-argument + quantizer: Optional[Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """ + All-gather FP8 tensor along first dimension for blockwise quantization. 
+ + Returns: quantizer(gather(inp)) + + NOTE: The implementation is not sophisticated enough to honor async_op=True. + In some cases it falls back to synchronous gather and invokes the quantizer. + """ + + # Input tensor attributes + device: torch.device + dtype: torch.dtype + if isinstance(inp, torch.Tensor): + device = inp.device + dtype = inp.dtype + elif isinstance(inp, Float8BlockwiseQTensorBase): + if inp._rowwise_data is not None: + device = inp._rowwise_data.device + elif inp._columnwise_data is not None: + device = inp._columnwise_data.device + else: + raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " + f"found {inp.__class__.__name__})" + ) + world_size = get_distributed_world_size(process_group) + + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") + + # Output tensor dims + if out_shape is None: + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # Doing BF16 gather for now as baseline because it's simpler + if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + out = quantizer(out) + return out, None + # Implementation of fp8 gather needs to account for: + # * Getting columnwise data as a transpose of how it is stored for GEMMS. + # * Gathering non GEMM swizzled scales. 
+ # * Refer to scaffold code when implementing at: + # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 + raise NotImplementedError("fp8 blockwise allgather not yet implemented") + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1075,7 +1144,9 @@ def gather_along_first_dim( async_op: bool = False, quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: - """All-gather tensors and concatenate along first dimension.""" + """ + All-gather tensors and concatenate along first dimension. + """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) @@ -1100,6 +1171,16 @@ def gather_along_first_dim( out_shape=out_shape, ) + # FP8 block scaling case, block length = 128 + if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + return _all_gather_fp8_blockwise( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # MXFP8 case if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 38f829c079..c02ff73391 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -19,6 +20,7 @@ Format, MXFP8BlockScaling, Float8CurrentScaling, + Float8BlockScaling, ) from .constants import dist_group_type @@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." 
+def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ): + return True, "" + return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." + + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -109,6 +122,8 @@ class FP8GlobalStateManager: skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None @classmethod def reset(cls) -> None: @@ -134,6 +149,8 @@ def reset(cls) -> None: cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -161,6 +178,15 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + if cls.fp8_block_scaling_available is None: + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( + check_fp8_block_scaling_support() + ) + return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -434,6 +460,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + 
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -786,8 +815,10 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState + elif recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState else: - raise ValueError("{recipe.__class__.__name__} is not supported") + raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, @@ -928,3 +959,108 @@ def make_quantizers(self) -> list: from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. 
+ from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + 
block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 31a464caad..65f47a0817 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,8 +35,10 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -516,6 +519,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -858,7 +865,13 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ): grad_output = quantizer(grad_output) @@ -876,11 +889,21 @@ def grad_output_preprocess( # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(quantizer, Float8BlockQuantizer): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
+ grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..1ea66a7f2c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -91,6 +91,8 @@ def forward( # TODO Support Float8 Current Scaling # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f49bad48c3..df3ae05f31 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -57,9 +57,11 @@ restore_from_saved, ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + from ..cpp_extensions import ( general_gemm, ) @@ -138,11 +140,6 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - # Avoid quantized norm kernel if norm output will 
be returned - with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered - ) - tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output @@ -175,6 +172,18 @@ def forward( columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. + force_hp_blockwise_ln_out_gather = ( + fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + with_quantized_norm = ( + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not force_hp_blockwise_ln_out_gather + ) + # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( @@ -211,7 +220,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -317,6 +326,7 @@ def forward( ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: @@ -327,6 +337,10 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # For force_hp_blockwise_ln_out_gather, we should + # be saving the unquantized ln_out to ctx. + assert not force_hp_blockwise_ln_out_gather + # Weight with column-wise usage is needed for dgrad GEMM. 
if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) @@ -605,11 +619,14 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + # async_op is not compatible with high precision gather since + # gather_along_first_dim does not offer callback chaining. + gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -690,6 +707,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather may have been done in BF16 + # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..e51fe43cc0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -52,7 +52,6 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) - from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -62,6 +61,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..tensor.quantized_tensor import ( @@ 
-104,17 +104,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } # no activation fusion written yet - # Per-tensor current scaling: [] - return { - "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), - "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, None), - "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, None), - } + # Per-tensor current scaling or fp8 blockwise scaling: [] + if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), + } + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -122,7 +124,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] - # Per-tensor current scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling: [] funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -214,12 +216,20 @@ def forward( with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered ) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + # 
Kernels not available for norm fusion. + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe + force_hp_fc1_input_gather = ( + fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + # Configure quantizer for norm output if fp8: if fc1_input_quantizer is None: @@ -261,12 +271,13 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8: - ln_out = fc1_input_quantizer(ln_out) + if not force_hp_fc1_input_gather: + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -282,7 +293,10 @@ def forward( quantizer=(fc1_input_quantizer if fp8 else None), ) else: - if fp8 and not with_quantized_norm: + # NOTE: force_hp_fc1_input_gather is redundant with else, but + # here for clarity. We should not quantize ln_out if bwd needs + # to gather in hp. + if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out @@ -336,6 +350,7 @@ def forward( # - bias_gelu_fusion - only for full precision. # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer if activation != "gelu": + # blockwise scaled gemms don't support gemm_gelu_fusion in fwd. 
gemm_gelu_fusion = bias_gelu_fusion = False else: if fp8: @@ -376,7 +391,12 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs else: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise. + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -462,6 +482,8 @@ def forward( if not return_layernorm_output: clear_tensor_data(ln_out) ln_out = None + elif force_hp_fc1_input_gather: + assert not isinstance(ln_out, QuantizedTensor) if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None @@ -490,6 +512,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -505,6 +528,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -696,11 +720,12 @@ def backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) + gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) else: ln_out_total = ln_out @@ -712,12 +737,13 @@ def backward( ) else: 
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - # There are 5 possible fusion paths + # There are 6 possible fusion paths # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) ) @@ -753,6 +779,9 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + grad_arg = True + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + grad_arg = False fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( act_out, grad_output, @@ -764,14 +793,18 @@ def backward( ), quantization_params=None, # wgrad in high precision layout="NT", - grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + grad=grad_arg, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None: + # BGRAD not fused with GEMM for float8 blockwise gemm. 
+ fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ clear_tensor_data(act_out) # bias computation @@ -808,7 +841,14 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO float8 blockwise current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -904,6 +944,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -1556,7 +1603,8 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + rowwise=True, + columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ca9dd29043..2887b2e452 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -60,9 +60,10 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + __all__ = ["Linear"] @@ -130,6 +131,10 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. + force_hp_input_gather = ( + fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. 
if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -143,19 +148,27 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - if not isinstance(inputmat, QuantizedTensor): - columnwise_usage = backward_needs_input and isinstance( - input_quantizer, MXFP8Quantizer + if force_hp_input_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, tp_group, quantizer=input_quantizer + ) + else: + if not isinstance(inputmat, QuantizedTensor): + columnwise_usage = backward_needs_input and isinstance( + input_quantizer, MXFP8Quantizer + ) + # force_hp_input_gather should enforce this + assert not isinstance(input_quantizer, Float8BlockQuantizer) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, ) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - inputmat = input_quantizer(inputmat) - own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=input_quantizer, - ) else: if ( FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() @@ -277,6 +290,8 @@ def forward( # can be allgathered. if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if force_hp_input_gather: + assert not isinstance(inputmat, QuantizedTensor) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. 
@@ -323,8 +338,9 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -520,11 +536,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -610,6 +627,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index b451acea9a..f4d4254537 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,6 +23,7 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext @@ -483,6 +484,12 @@ def _functional_forward( "Attempting to generate MXFP8 output tensor, " "but GEMM with MXFP8 output is not supported" ) + if isinstance(output_quantizer, Float8BlockQuantizer): + raise RuntimeError( + "Attempting to generate Float8BlockQuantized output tensor, " + "but GEMM with Float8BlockQuantized output is not supported" + ) + if output_quantizer is not None: output_quantizer.set_usage(rowwise=True, columnwise=False) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index ad32055479..802f4c25e3 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -17,6 +17,7 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, fp8_autocast, @@ -219,6 +220,11 @@ def _reset_quantization_recipe_state( if num_quantizers == 0: continue + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) + # Construct quantization 
recipe state recipe_state = RecipeState.create( recipe, @@ -260,8 +266,13 @@ def _update_quantization_recipe_state( continue recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( - recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) + or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + or ( + recipe.float8_block_scaling() + and not isinstance(recipe_state, Float8BlockScalingRecipeState) + ) + ) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 9135237854..7dc380606d 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase: def __new__( cls, *args, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, @@ -71,10 +71,16 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward""" + """ + Prepare the tensor base for saving for backward + + This does not clear the tensors currently, because with PP config + that clears the weight cache between micro-batches. If the rowwise + data is not required for backward, this is a possible memory + pessimization, but is consistent with the other quantized tensor + classes. 
+ """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 138d1fd29e..695c5ffb8c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -44,7 +44,6 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - assert rowwise self.dtype = fp8_dtype self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales @@ -168,6 +167,11 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.append(shape[i]) return tuple(colwise_shape) + # TODO(kwyss): With FP8 gather support, we need to implement a + # shape/layout/swizzle check to know whether FP8 gather works + # cleanly by stacking data without aliasing tiles and whether + # the scales also stack on the proper dimensions. 
+ def make_empty( self, shape: Iterable[int], @@ -181,13 +185,16 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -489,7 +496,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst.dtype = src.dtype # Check that tensor dimensions match if (