diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7e6605c9fe..eabd1b2a3f 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -448,8 +448,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ba62d964fa..839bc3175e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -416,8 +416,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1300be01bb..df78157cc5 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -327,8 +327,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = 
is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 4022cb7493..435750a1db 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -306,8 +306,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3253861484..3fbfb9cf5c 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +. 
$TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eaededc4..1206012195 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail " python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cc27f72769..10b52e065f 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -19,6 +19,12 @@ using namespace test; namespace { +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + constexpr size_t kBlockLen = 128; enum ProcessingMethod { @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, 
const std::vector(&input, fill_case); fillUniform(&grad); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, - opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, fillUniform(&grad); Tensor workspace; + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, - {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {1, 16}, {65, 96}, {256, 256}, {993, 512}, + {256, 65536}, {4096, 1632}, {1024, 1}, + {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, }; std::vector input_scenarios = { @@ -429,6 +441,8 @@ std::vector 
Activation_types = { std::vector amax_epsilons = { 0.0f, + 1.0f, // Make large to be observable. + }; } // namespace @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 071c2186e0..61d3075265 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -216,8 +216,7 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode, - const QuantizationOptions* q_opts) { + const NVTEScalingMode &scaling_mode) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } - if (q_opts != nullptr) { - NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); - 
NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation."); - } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 08df3cf7d1..d5ecc6d0f5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,29 +95,21 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; -struct QuantizationOptions { - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - size_t block_scaling_dim = 2u; -}; - class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4dc07a2eea..8917e92465 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -48,21 +48,21 @@ LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" if is_fp8_supported: - 
supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) + supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) def is_shape_supported_by_mxfp8(input_shape): try: if isinstance(input_shape, type(pytest.param(0))): input_shape = input_shape.values[0] - ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) + ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True except: # get_scale_shapes will raise an exception if the shape is not supported @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=q_layout, ) @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8( self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout, ) @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8( 
epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_layout=q_layout, ) @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -533,7 +533,7 @@ class TestFusedQuantize: def test_quantize_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis ): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( + if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization( in_dtype=in_dtype, input_shape=input_shape, out_dtype=in_dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=QuantizeLayout.ROWWISE, @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( in_dtype=in_dtype, 
input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout): scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = 
value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 6d4cde364f..476d455a6a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -29,7 +29,7 @@ } is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 4350d5e8f3..cf311ac404 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -36,7 +36,7 @@ is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b89530c19f..a21583a98c 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -39,7 +39,7 @@ def enable_fused_attn(): is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) QUANTIZE_RECIPES = [] """ Find supported scaling modes""" @@ -313,7 +313,7 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: _, 
updated_quantize_meta = flax.core.pop( updated_state[0], QuantizeConfig.COLLECTION_NAME ) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ae5993eb1e..b423bce53d 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -19,6 +19,7 @@ MXFP8BlockScaling, DelayedScaling, Float8CurrentScaling, + Float8BlockScaling, Format, Recipe, ) @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe: return MXFP8BlockScaling() if QUANTIZATION == "fp8_cs": return Float8CurrentScaling() + if QUANTIZATION == "fp8_block_scaling": + return Float8BlockScaling() return te.fp8.get_default_fp8_recipe() @@ -85,7 +88,7 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"): global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE SEQ_LEN = 32 BATCH_SIZE = 32 diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index b4e2b680b3..632f50e90a 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -28,6 +28,9 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -48,7 +51,7 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -56,4 +59,6 @@ def 
test_distributed(quantization): pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) _run_test(quantization) diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py new file mode 100644 index 0000000000..5aef986e37 --- /dev/null +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128): + pid = tl.program_id(0) + idx = pid * BLOCK + tl.arange(0, BLOCK) + mask = idx < M * N + + row = idx // N + col = idx % N + + y_offset = row * y_str0 + col * y_str1 + x_offset = row * N + col + s_offset = row * N + col + + y = tl.load(y_ptr + y_offset, mask=mask) + x = tl.load(x_ptr + x_offset, mask=mask) + s = tl.load(s_ptr + s_offset, mask=mask) + + tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask) + + +def fused_fma(y, x, s, BLOCK=128): + """ + Fused multiply-add operation (y = y + x * s). + + PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation). + This function also supports cases where 'y' is non-contiguous in memory. 
+ """ + + assert ( + y.shape == x.shape == s.shape and y.dim() == 2 + ), "All tensors must be 2D with the same shape" + assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous" + + M, N = y.shape + grid = ((M * N + BLOCK - 1) // BLOCK,) + + fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK) + + return y + + +class CuBLASRefBlockwiseGemm: + """ + A cuBLAS compatible reference implementation of subchannel GEMM. + """ + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + demunged_sx: torch.Tensor, + demunged_sw: torch.Tensor, + quant_tile_shape_x: Tuple[int, int], + quant_tile_shape_w: Tuple[int, int], + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + use_split_accumulator: bool = False, + ) -> torch.Tensor: + # demunge scale shapes for cuBLAS + is_a_1d_scaled = quant_tile_shape_x[0] == 1 + is_b_1d_scaled = quant_tile_shape_w[0] == 1 + M, K = qx.shape + N, K = qw.shape + + # mm_tile_shape = (tile_m, tile_n, tile_k) + mm_tile_shape = ( + quant_tile_shape_x[0], + quant_tile_shape_w[0], + quant_tile_shape_w[1], + ) + if bias is not None and bias.numel(): + # To match cuBLAS more closely when bias is applied, + # the reference accumulates into float32, and cast to + # bfloat16 is deferred until after the GEMM. + out_dtype_for_ref = torch.float32 + else: + out_dtype_for_ref = out_dtype + y = self.qgemm_blockwise_2d( + qx, + qw, + out_dtype_for_ref, + demunged_sx, + demunged_sw, + mm_tile_shape, + use_split_accumulator, + is_a_1d_scaled, + is_b_1d_scaled, + ) + if bias is not None and bias.numel(): + y += bias + y = y.to(dtype=out_dtype) + # cublas accumulation first convert to output dtype, then accumulate. + if accumulate: + assert out is not None + y = y + out + else: + assert out is None, "Output tensor should be None when accumulate is False." 
+ + return y + + @classmethod + def qgemm_blockwise_2d( + cls, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + mm_tile_shape: Tuple[int, int, int], + use_split_accumulator: bool, + is_a_1d_scaled: bool, + is_b_1d_scaled: bool, + ) -> torch.Tensor: + """ + Difference between cuBLAS and CUTLASS GEMM implementations: + - cuBLAS accumulation equation: use different equation for each scaling mode. + - For accumulation C in epiloge, it first convert C to output dtype, then accumulate. + """ + + M, K = qx.shape + N, K_w = qw.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + tile_len = 128 + # Calculate grid sizes without padding + grid_m = (M + tile_len - 1) // tile_len + grid_n = (N + tile_len - 1) // tile_len + grid_k = (K + tile_len - 1) // tile_len + + block_m, block_n, block_k = mm_tile_shape + scale_m_per_tile = tile_len // block_m + scale_n_per_tile = tile_len // block_n + assert block_k == tile_len, "block_k must be equal to tile_len" + + # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM: + # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers. 
+ # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # Validate shapes of sx and sw + scale_m_per_tensor = (M + block_m - 1) // block_m + scale_n_per_tensor = (N + block_n - 1) // block_n + assert sx.shape == ( + scale_m_per_tensor, + grid_k, + ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}" + assert sw.shape == ( + scale_n_per_tensor, + grid_k, + ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}" + + for i in range(grid_m): + m_start = i * tile_len + m_end = min(m_start + tile_len, M) + m_size = m_end - m_start + + for j in range(grid_n): + n_start = j * tile_len + n_end = min(n_start + tile_len, N) + n_size = n_end - n_start + + y_block = y[m_start:m_end, n_start:n_end] + + for k in range(grid_k): + k_start = k * tile_len + k_end = min(k_start + tile_len, K) + k_size = k_end - k_start + + qx_block = ( + qx[m_start:m_end, k_start:k_end].clone().contiguous() + ) # Shape: [m_size, k_size] + qw_block = ( + qw[n_start:n_end, k_start:k_end].clone().contiguous() + ) # Shape: [n_size, k_size] + + # Extract scaling factors for the current blocks + sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze( + -1 + ) + sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0) + + # Perform qgemm with scaling factors fused in the GEMM + # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM + one = torch.tensor(1.0, dtype=torch.float32, device=qx.device) + y_partial = torch._scaled_mm( + qx_block, + qw_block.t(), + scale_a=one, + scale_b=one, + out_dtype=torch.float32, + use_fast_accum=not use_split_accumulator, + ) + + # Accumulate the partial result + if is_a_1d_scaled and is_b_1d_scaled: + # 1Dx1D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + 
y_partial = y_partial * sx_block + # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM + # y_block.add_(y_partial, alpha=scale.item()) + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + elif not is_a_1d_scaled and is_b_1d_scaled: + # 2Dx1D + # CuBLAS accumulation equation: y += (y * scale_b) * scale_a + y_partial = y_partial * sw_block + fused_fma( + y_block, + y_partial, + sx_block.expand_as(y_partial).contiguous(), + ) + elif is_a_1d_scaled and not is_b_1d_scaled: + # 1Dx2D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + else: + scale = sx_block * sw_block + fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous()) + + y = y.to(out_dtype) + return y diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index b98966f514..f5c9dc0e96 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + @classmethod def demunge_scale_shape_from_backend( cls, qtensor_shape: Tuple[int, int], diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5a1dc3f732..7bfe506f26 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -27,6 +27,9 @@ # Check if FP8 is supported. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -55,6 +58,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] # Supported data types @@ -316,9 +320,13 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and module == "linear_op": + pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") # Run model with different CUDA graph settings. model_config = model_configs[model_config] kwargs = dict( diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py new file mode 100644 index 0000000000..ec23cfe8c5 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -0,0 +1,972 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from references.blockwise_quantizer_reference import CuBLASScaleMunger +from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm + + +def fp8_blockwise_gemm_supported() -> bool: + supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + return supported + + +def cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + atol: float = 0.0, + rtol: float = 0.0 +): + if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2: + pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2") + if not (is_x_1d_scaled or is_w_1d_scaled): + pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile") + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + if noise_type == "uniform": + x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude + w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude + elif noise_type == "normal": + x = torch.randn(x_shape, dtype=torch.float32, device=device) * 
x_magnitude + w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude + else: + assert False + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude + else: + out = None + + assert not (use_bias and use_grad), "Bias grad not supported by GEMM" + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Reference GEMM + ref_gemm = CuBLASRefBlockwiseGemm() + scale_decoder = CuBLASScaleMunger() + qx_data = ( + qx._columnwise_data.view(dtype=x_dtype) + if x_columnwise + else qx._rowwise_data.view(dtype=x_dtype) + ) + qw_data = ( + qw._columnwise_data.view(dtype=w_dtype) + if w_columnwise + else qw._rowwise_data.view(dtype=w_dtype) + ) + ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv + ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv + y_ref = 
ref_gemm.qgemm( + qx=qx_data, + qw=qw_data, + out_dtype=out_dtype, + demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape + ), + demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape + ), + quant_tile_shape_x=x_quant_tile_shape, + quant_tile_shape_w=w_quant_tile_shape, + bias=bias, + out=out.clone() if accumulate else None, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM" + aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None + aux_tensor_ref = aux_tensor.clone() if use_gelu else None + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
+ y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + aux_tensor, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y are not the same tensor + assert y_ref is not y, "y_ref and y should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y = torch.where(y.isnan(), torch.zeros_like(y), y) + + if use_gelu: + # Check + if use_grad: + # With use_grad, GEMM should use aux tensor to calculate + # gradient + gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None) + # TODO: How do we decide whether this is acceptably close? + # Could also try to put the activation inside the reference + # before the output cast to see different tolerances. + torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2) + else: + # aux tensor is pre-gelu aux output. Verify against y_ref. 
+ torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol) + act = torch.nn.GELU() + gelu_ref = act(y_ref) + # gelu_ref = tex.gelu(y_ref, None) + torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol) + else: + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + +def cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED", + expected_err_cls=RuntimeError +): + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + 
force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = use_grad + gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + with pytest.raises(expected_err_cls, match=expected_err_msg): + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (128, 128, 128), + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 128, 336), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (320, 256, 336), + (1024, 4096, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], 
ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + (320, 256, 336), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + 
accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_bias( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_bias=True, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) 
+@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["colxrow", "colxcol", "rowxcol"], +) +def test_cublas_gemm_fp8_blockwise_columnwise( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + is_x_columnwise, + is_w_columnwise, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", 
"1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "use_grad", + [ + True, + ], + ids=["grad"], +) +def test_cublas_gemm_fp8_gelu( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad, +): + # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed + # so the epilogue is disabled on the transformer engine side. + if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled): + pytest.skip( + "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)." + ) + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_gelu=True, + use_grad=use_grad, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_split_accumulator_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", 
[torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_bgrad_not_supported( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: BGRAD epilogue is not supported for fp8. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=True, + use_bias=True, + expected_err_msg="Epilogue requested outside of the available", + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_gelu_unsupported_cases_error( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_bias, + use_grad, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + if 
use_grad and not use_bias and out_dtype == torch.bfloat16: + pytest.skip("DGELU epilogue is supported for bfloat16.") + elif use_grad and not use_bias: + expected_err = "an unsupported value or parameter was passed" + else: + expected_err = "Epilogue requested outside of the available" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=use_grad, + use_bias=use_bias, + use_gelu=True, + expected_err_msg=expected_err, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_illegal_dtype_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # e5m2 by e5m2 not supported. 
+ cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (False, False), + ], + ids=["2Dx2D"], +) +def test_illegal_2D_by_2D_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # 2D block quantization by 2D block quantization is not supported. + expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg=expected_err_msg, + ) + + +@pytest.mark.parametrize( + "M, K, N, legalX1d, legalX2d", + [ + # M dim unconstrained when X is 2D. 
+ (255, 128, 256, False, True), + # K must be multiple of 16 + (256, 120, 256, False, False), + # N must be a multiple of 8 + (256, 128, 252, False, False), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (False, True), + (True, True), + ], + ids=["1Dx2D", "2Dx1D", "1Dx1D"], +) +def test_unaligned_shapes( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + legalX1d, + legalX2d, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + legal = legalX1d if is_x_1d_scaled else legalX2d + if not legal: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg="dimension requirement", + ) + else: + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + "uniform", # noise type + 1.0, # x_magnitude + 1.0, # w_magnitude + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e638fe8c5b..0baee4975d 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -4,11 +4,14 @@ from typing import Tuple import math +import os +import pathlib import pytest import torch import transformer_engine as te import transformer_engine_torch as tex -from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from 
transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -18,10 +21,29 @@ BlockwiseQuantizerReference, QuantizeResult, ) +from test_float8_current_scaling_exact import ( + TestFP8RecipeLinearBase, + TestFP8RecipeLayerNormLinearBase, +) + +# read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + + +class GetRecipes: -# TODO replace with call to fp8.py when recipe added. -recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 -reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + @staticmethod + def none(): + return None + + @staticmethod + def fp8_blockwise(): + # return default configs + return Float8BlockScaling() def initialize_for_many_scales( @@ -66,35 +88,7 @@ return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles.
- (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] -) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) -def test_quantization_block_tiling_versus_reference( +def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, @@ -199,12 +193,90 @@ def test_quantization_block_tiling_versus_reference( [ # full tile cases (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # 
full tile cases + (256, 256), + (2048, 1024), + # Padding required cases + (256, 272), + (303, 300), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference_fp32_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( @@ -292,3 +364,130 @@ atol=0.0, rtol=0.0, ) + + +# FP8 block scaling +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + 
seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + 
fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1.6, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 9741b1258c..8911ecc159 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,8 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - return torch.mean(torch.abs((a - b) / b)) + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) @staticmethod def _load_golden_tensor_values(a, b): @@ -97,9 +98,12 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { @@ -437,9 
+441,13 @@ def _check_golden_tensor_dumps( fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" + current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index d030426b74..6d3e879970 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -110,7 +110,10 @@ def _test_quantize_dequantize( dims = _to_list(dims) # Initialize random data + # Note: Make sure values are not all close to zero, or else + # test may pass trivially. x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref.view(-1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back @@ -150,6 +153,24 @@ def test_quantize_dequantize_dtypes( ) self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1]) + def test_quantize_dequantize_columnwise_only( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=False, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True + ) + @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 
5, 3]] ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 42600e3099..d36da704b0 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import io -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytest import torch @@ -158,6 +158,32 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("noop", [True, False]) + def test_quantize_dequantize_noop( + self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool + ) -> None: + noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") + if noop: + noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") + dims = 23 + scale: float = 3.5 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor) + if noop_tensor.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) + def test_basic_ops( self, dims: DimsType = 23, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9c1a842cd8..59af228861 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Iterable +import io import math from typing import 
Optional @@ -1420,15 +1421,17 @@ def test_activation( test_device=device, test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_is_fp8=quantized_compute, requires_grad=False, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1459,6 +1462,7 @@ def test_activation( swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantized_compute), make_op(), te_ops.Quantize(forward=quantized_compute, backward=False), ) @@ -1882,3 +1886,118 @@ def test_backward_linear_add( torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +class TestCheckpointing: + """Tests for checkpointing""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear( + self, + *, + pre_checkpoint_steps: int = 2, + post_checkpoint_steps: int = 2, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Check checkpointing with linear op""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + 
maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + + # Construct model + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_save = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25) + + # Warmup training steps + for _ in range(pre_checkpoint_steps): + x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn(out_shape, dtype=dtype, device=device) + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(x) + y.backward(dy) + optim_save.step() + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save( + {"model": model_save.state_dict(), "optim": optim_save.state_dict()}, + byte_stream, + ) + checkpoint_bytes = byte_stream.getvalue() + del byte_stream + + # Synthetic data for evaluation + xs_save = [ + torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + for _ in range(post_checkpoint_steps) + ] + with torch.no_grad(): + xs_load = [x.clone().requires_grad_() for x in xs_save] + dys = [ + torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps) + ] + + # Training steps with original model + ys_save = [] + for i in range(post_checkpoint_steps): + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(xs_save[i]) + y.backward(dys[i]) + optim_save.step() + ys_save.append(y) + + # Load checkpoint + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_load = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) + state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) + 
model_load.load_state_dict(state_dict["model"]) + optim_load.load_state_dict(state_dict["optim"]) + + # Training steps with loaded model + ys_load = [] + for i in range(post_checkpoint_steps): + optim_load.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_load(xs_load[i]) + y.backward(dys[i]) + optim_load.step() + ys_load.append(y) + + # Check that original and loaded model match exactly + tols = {"rtol": 0, "atol": 0} + for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): + torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close(param_load.grad, param_save.grad, **tols) + for y_load, y_save in zip(ys_load, ys_save): + torch.testing.assert_close(y_load, y_save, **tols) + for x_load, x_save in zip(xs_load, xs_save): + torch.testing.assert_close(x_load.grad, x_save.grad, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..7a930b6cde 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,6 +50,9 @@ # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) sm_80plus = get_device_compute_capability() >= (8, 0) @@ -104,6 +107,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] @@ -563,6 +567,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -675,6 +681,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1528,6 +1536,8 @@ def test_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Grouped linear for FP8 blockwise unsupported.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1723,6 +1733,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Float8 block scaling unsupported for grouped linear.") config = model_configs[model] if 
config.seq_len % 16 != 0 and fp8: @@ -1933,6 +1945,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 69ac8f7996..d8552d63a4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,6 +46,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -106,6 +109,7 @@ def is_fp8_supported(self): None, # Test non-FP8 recipe.MXFP8BlockScaling(), # Test default recipe.Float8CurrentScaling(), # Test default + recipe.Float8BlockScaling(), # Test default recipe.DelayedScaling(), # Test default recipe.DelayedScaling( # Test most_recent algo amax_history_len=16, @@ -439,6 +443,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -470,6 +476,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -502,6 
+510,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -543,10 +553,14 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8(): pytest.skip("Grouped linear does not support MXFP8") if fp8_recipe.float8_current_scaling(): pytest.skip("Grouped linear does not support FP8 current scaling") + if fp8_recipe.float8_block_scaling(): + pytest.skip("Grouped linear does not support FP8 block scaling") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -590,6 +604,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -640,6 +656,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -707,6 +725,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not 
fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -766,6 +786,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -823,6 +845,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -858,6 +882,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -896,6 +922,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -937,6 +965,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if 
fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -979,8 +1009,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling(): + pytest.skip("cuda graph not supported for float8_block_scaling recipe") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 708403f911..67f173a4ab 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..728f8ad147 100644 --- a/transformer_engine/common/common.h +++ 
b/transformer_engine/common/common.h @@ -233,10 +233,12 @@ struct Tensor { struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; + NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float) // amax_epsilon + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f19465c44b..0cd0762ee5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -52,97 +52,173 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } +/* Parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ struct GemmParam { - void *A; - void *B; - cublasOperation_t transA; - cublasOperation_t transB; - transformer_engine::DType Atype; - transformer_engine::DType Btype; - void *A_scale_inv; - void *B_scale_inv; - int lda; - int ldb; - - GemmParam(cublasOperation_t transA, cublasOperation_t transB) - : A(nullptr), - B(nullptr), - transA(transA), - transB(transB), - Atype(transformer_engine::DType::kNumTypes), - Btype(transformer_engine::DType::kNumTypes), - A_scale_inv(nullptr), - B_scale_inv(nullptr), - lda(0), - ldb(0) {} + void *A = nullptr; + void *B = nullptr; + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + transformer_engine::DType Atype = transformer_engine::DType::kNumTypes; + transformer_engine::DType Btype = transformer_engine::DType::kNumTypes; + void *A_scale_inv = nullptr; + void *B_scale_inv = nullptr; + int lda = 0; // A column strides + int ldb = 0; // B column strides }; +/* Populate parameters for cuBLAS GEMM + * + * cuBLAS 
follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int k, const int lda, const int ldb) { + int m, int n, int k) { using namespace transformer_engine; - // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. - // Must either force them both into a common block scaling mode or loosen this - // restriction. - NVTE_CHECK(A.scaling_mode == B.scaling_mode, - "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK( + A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); - GemmParam ret(transA, transB); + GemmParam ret; + + // Device compute capability + const int arch = cuda::sm_arch(); - ret.lda = lda; - ret.ldb = ldb; + // Transpose mode with column-major ordering + bool is_A_transposed = transA == CUBLAS_OP_T; + bool is_B_transposed = transB == CUBLAS_OP_T; - // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases - // or need to be treated as `is_tensor_scaling`. + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.A = A.data.dptr; + ret.transA = transA; + ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - if (transA == CUBLAS_OP_T) { - ret.Atype = A.data.dtype; - } else { - ret.Atype = A.has_columnwise_data() ? 
A.columnwise_data.dtype : A.data.dtype; - if (is_fp8_dtype(ret.Atype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - ret.A = A.columnwise_data.dptr; - ret.transA = CUBLAS_OP_T; - ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.lda = k; - } + ret.lda = is_A_transposed ? k : m; + if (arch < 100 && !is_A_transposed) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } + } else if (is_mxfp_scaling(A.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = transA; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? k : m; + } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. 
+ if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.lda % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((m % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } else { + NVTE_ERROR("A has unsupported scaling mode"); + } + + // Configure B matrix + if (is_tensor_scaling(B.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.B = B.data.dptr; + ret.transB = transB; + ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - if (transB == CUBLAS_OP_T) { - ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; - if (is_fp8_dtype(ret.Btype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - ret.B = B.columnwise_data.dptr; - ret.transB = CUBLAS_OP_N; - ret.B_scale_inv = B.columnwise_scale_inv.dptr; - ret.ldb = k; - } + ret.ldb = is_B_transposed ? n : k; + if (arch < 100 && is_B_transposed) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. 
+ if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } + } else if (is_mxfp_scaling(B.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (is_B_transposed) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = transB; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = is_B_transposed ? n : k; + } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (is_B_transposed) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { - ret.Btype = B.data.dtype; + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. 
Caller must pad."); + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { + // Observed this requirement only present for B tensor is 1D quantized. + NVTE_CHECK((n % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } } else { - // If not tensor scaling (which includes also high precision types), we need to - // use the proper version of data - // We leave the transA/B values as is, since Blackwell supports transposes - ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; - ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; - ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + NVTE_ERROR("B has unsupported scaling mode"); } + return ret; } @@ -153,18 +229,33 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, - int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, cudaStream_t stream) { + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + // Tensor dims in row-major order + const int A0 = inputA->flat_first_dim(); + const int A1 = inputA->flat_last_dim(); + const int B0 = inputB->flat_first_dim(); + const int B1 = inputB->flat_last_dim(); + 
+ // GEMM dims in column-major order + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, + "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, + ")"); + const int ldd = m; + // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); + const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -226,6 +317,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, param.transA == CUBLAS_OP_N ? k : m, param.lda)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -249,12 +341,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized - // GEMM types. - // Scaling factors. 
#if CUDA_VERSION >= 12080 - cublasLtMatmulMatrixScale_t scaling_mode; + cublasLtMatmulMatrixScale_t scaling_mode_a; + cublasLtMatmulMatrixScale_t scaling_mode_b; #endif if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; @@ -266,8 +356,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); #if CUDA_VERSION >= 12080 - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; - } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -276,7 +367,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
if (cublasLtGetVersion() <= 120803) { @@ -285,7 +377,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#endif + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { +#if CUDA_VERSION >= 12090 + float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D + ? 
CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#else + NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); +#endif // CUDA_VERSION >= 12090 +#endif // CUDA_VERSION >= 12080 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); @@ -293,9 +410,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output @@ -305,8 +422,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. 
+ NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); #endif // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 @@ -364,6 +484,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) || + (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) { + NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || + epilogue == CUBLASLT_EPILOGUE_DGELU), + "Epilogue requested outside of the available and tested cuBLAS functionality for " + "float8 block scaled GEMM"); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); @@ -411,7 +539,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C @@ -469,35 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const size_t A0 = inputA->flat_first_dim(); - const size_t A1 = inputA->flat_last_dim(); - const size_t B0 = inputB->flat_first_dim(); - const size_t B1 = inputB->flat_last_dim(); - - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? 
B1 : B0; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -525,31 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, + inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7fa7957fa4..64136b2c43 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,7 +89,7 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. @@ -102,6 +102,16 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); +/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output quantized tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ba47b9d38c..d3ee446f83 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -286,6 +286,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, + /*! Noop tensor (containing a scalar). + If the scalar element value = 1, quantization kernel will early exit. + This is a tensor because the flag must be on GPU in order to enable + conditional early exit even when captured in a static CUDA graph. + */ + kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; @@ -724,6 +730,12 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } + /*! \brief Set noop tensor pointer */ + void set_noop_tensor(NVTETensor noop_tensor) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor, + sizeof(NVTETensor)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index dae39d82bf..f6b6ae22c2 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 8519fe1b64..c56f9ef407 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend 
norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b676bf6ab0..80857e565c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -5,6 +5,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations import warnings +import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -81,6 +82,10 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + def float8_block_scaling(self): + """Whether the given recipe is float8 blockwise scaling.""" + return isinstance(self, Float8BlockScaling) + @dataclass() class DelayedScaling(Recipe): @@ -287,3 +292,99 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + + +@dataclass() +class Float8BlockScaling(Recipe): + """ + Use block-wise scaling for FP8 tensors. + + In this strategy, tensors are scaled in blockwise fashion. Values within + each block share a common scaling factor. The block dimensionality + can be configured. The scaling factors are float32 containers. They + will by default be constrained to powers of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), the quantized tensor and its transpose are not numerically + equivalent. Due to this, when Transformer Engine needs both the FP8 tensor + and its transpose (e.g. 
to calculate both forward and backward pass), + during the quantization both versions are computed from the high precision + input to avoid double quantization errors. + + NOTE: To relax the default constraint that scales be powers of 2, set env variable + NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. + export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code for increased control. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of gradient tensor dY + x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for x. + w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for w. + grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for grad. 
+ fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=True + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating wgrad in backward pass + """ + + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" + + fp8_format: Format = Format.E4M3 + fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + x_block_scaling_dim: int = 1 + w_block_scaling_dim: int = 2 + grad_block_scaling_dim: int = 1 + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" + assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" + assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" + assert not ( + self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." + assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." 
+ assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"x_block_scaling_dim={self.x_block_scaling_dim}, " + f"w_block_scaling_dim={self.w_block_scaling_dim}, " + f"grad_block_scaling_dim={self.grad_block_scaling_dim}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 97df5892b6..706e0bd0b5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -429,6 +429,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(buf, &config_.amax_epsilon, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(buf, &config_.noop_tensor, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -458,6 +461,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(&config_.amax_epsilon, buf, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(&config_.noop_tensor, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 298d087337..3148b4f720 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ 
b/transformer_engine/common/transpose/cast_transpose.h @@ -29,11 +29,35 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream); +// enum class for rowwise usage +enum class FP8BlockwiseRowwiseOption { + // No rowwise data + NONE, + // Rowwise data, scales in GEMM format + ROWWISE + // TODO: FP8 all gather requires some changes. + // 1. Compact scales are better for gathering than the GEMM format. +}; + +// enum class for columnwise usage +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class FP8BlockwiseColumnwiseOption { + // No columnwise data + NONE, + // Columnwise data transposed from original shape. + // Scales in GEMM format corresponding to GEMM ingesting transposed column data. + COLUMNWISE_TRANSPOSE + // TODO: FP8 all gather requires some changes. + // 1. The transpose gets in the way of the all gather. + // 2. Compact scales are better for gathering than the GEMM format. 
+}; + void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 732d97999c..91f73dea1e 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -16,11 +16,15 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" #include "common/utils.cuh" namespace transformer_engine { namespace { +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + // clang-format off /* @@ -138,15 +142,17 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); template -__global__ void __launch_bounds__(kThreadsPerBlock) - block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, - OType* const output_t, CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, const size_t row_length, - const size_t num_rows, const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon, - bool return_transpose, bool pow_2_scaling) { +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, OType* const 
output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scaling) { + bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; + bool return_columnwise_transpose = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; + using SMemVec = Vec; using OVec = Vec; union IVec { @@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast and store to output_c - { + if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if (return_transpose) { + if (return_columnwise_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -389,10 +395,15 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, - cudaStream_t stream) { + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + // assert that rowwise_option and columnwise_option are not both NONE + 
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || + columnwise_option != FP8BlockwiseColumnwiseOption::NONE, + "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -408,21 +419,24 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. - - NVTE_CHECK(input.shape.size() == output.shape.size(), - "Input and output must have the same shape."); - size_t scale_stride_x = 0; size_t scale_stride_y = 0; - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; - size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (return_transpose) { + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + "Unexpected rowwise enum value"); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + } + + if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, + "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -469,10 +483,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, - pow2_scale);) // kAligned - ) // 
OutputType - ) // InputType + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 22a50025df..1f146c7a33 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no NVTE_API_CALL(nvte_quantize_noop); using namespace transformer_engine; + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = false; constexpr bool IS_ACT = false; @@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, noop, output, - dbias, workspace, stream); + detail::quantize_helper( + input, grad, output, dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ 
-66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +92,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +148,7 @@ 
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c6a8b0f23c..a599d88530 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,9 +1215,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe auto output_tensor = reinterpret_cast(output); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); + + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); switch (output_tensor->scaling_mode) { @@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. 
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, + output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - constexpr bool force_pow_2_scales = true; - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() + ? FP8BlockwiseRowwiseOption::ROWWISE + : FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + output_tensor->has_columnwise_data() ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE + : FP8BlockwiseColumnwiseOption::NONE; + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7676781c3..c27f6f50f7 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -162,7 +162,7 @@ def lowering( assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ) return out @@ -282,7 +282,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -293,9 +293,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -339,7 +339,7 @@ def partition( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == 
ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -350,9 +350,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -391,7 +391,7 @@ def sharded_impl(x, scale): ) ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -463,7 +463,7 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -545,7 +545,7 @@ def lowering( dz, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), @@ -673,7 +673,7 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -691,9 +691,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == 
ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -743,7 +743,7 @@ def partition( ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -761,9 +761,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -810,7 +810,7 @@ def sharded_impl(dz, x, scale): else: global_dbias = local_dbias - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -928,7 +928,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1042,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused @@ -1095,7 +1095,7 @@ def quantize_dact_dbias( ) # For 
DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1df2bcc97f..0327542c2f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -98,7 +98,7 @@ def lowering( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + scaling_mode=scaling_mode.value, ) @staticmethod @@ -123,7 +123,7 @@ def impl( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, out_flat_size=out_flat_size, ) @@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8( ): """FP8 GEMM for XLA pattern match""" assert ( - rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d( JAX GEMM for MXFP8 via scaled_matmul """ assert ( - rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING + rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper @@ -291,10 +291,10 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) - if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, 
dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") @@ -403,7 +403,7 @@ def grouped_gemm( rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -415,7 +415,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -427,13 +427,13 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) @@ -470,13 +470,13 @@ def grouped_gemm( dims.append((bm, bn, k)) lhs_contig_.append(lhs_3d.reshape(-1)) rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: 
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) if bias_list is not None: @@ -493,8 +493,8 @@ def grouped_gemm( # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + if scaling_mode == ScalingMode.NO_SCALING: + scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING # Perform batched GEMM on flattened inputs out_contig = GroupedGemmPrimitive.outer_primitive.bind( @@ -505,7 +505,7 @@ def grouped_gemm( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, out_flat_size=out_flat_size, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index c79eda5568..d64104ac27 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 74882c92db..388d4f17ee 
100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -105,6 +105,26 @@ def abstract( if norm_type == NVTE_Norm_Type.LayerNorm: assert gamma_aval.size == beta_aval.size + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + if norm_type == NVTE_Norm_Type.RMSNorm: + mu_aval = mu_aval.update(shape=(1,)) + + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -112,33 +132,13 @@ def abstract( jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(out_dtype), norm_type, - scaling_mode.value, + scaling_mode, zero_centered_gamma, epsilon, get_forward_sm_margin(), is_2x, ) - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_aval = mu_aval.update(shape=(1,)) - - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x_aval.shape, is_padded=not is_outer - ) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_aval = 
jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) - colwise_out_aval = jax.core.ShapedArray( - shape=x_aval.shape if is_2x else (1,), dtype=out_dtype - ) - - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - - wkspace_aval = x_aval.update( + wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) @@ -274,9 +274,9 @@ def impl( scale_shapes=scale_shapes, is_outer=False, ) - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x.shape, is_padded=False - ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) # slice out padding for mxfp8, noop for DelayedScaling scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape @@ -364,6 +364,8 @@ def infer_sharding_from_operands( del zero_centered_gamma, epsilon, out_dtype, result_infos del scale_dtype, scale_shapes, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) + out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -371,34 +373,27 @@ def infer_sharding_from_operands( "and hurt performance." 
) - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") + + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") output = ( out_sharding, colwise_out_sharding, @@ -427,8 +422,11 @@ def partition( ): del result_infos, is_outer x_spec = 
get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) g_spec = get_padded_spec(arg_infos[2]) b_spec = get_padded_spec(arg_infos[3]) + out_spec = (*x_spec[:-1], None) + if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -445,43 +443,30 @@ def partition( f"{NormFwdPrimitive.name} does not support sharding of parameter beta " "Enforcing no sharding of parameters hidden dim! " ) - x_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x" - ) - g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma") - b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta") - out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out") - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" + ) rsigma_sharding = NamedSharding( - mesh, - PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]), - desc="NormFwdPrimitive.rsigma", + mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), 
desc="NormFwdPrimitive.scale" - ) - scale_inv_sharding = scale_sharding.duplicate_with_new_description( - "NormFwdPrimitive.scale_inv" + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -517,7 +502,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -824,7 +809,6 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) - if quantizer is None: output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, @@ -835,7 +819,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -845,7 +829,7 @@ def layernorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed 
scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -864,7 +848,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -873,7 +857,7 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -882,7 +866,7 @@ def layernorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. 
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) @@ -1017,7 +1001,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1027,7 +1011,7 @@ def rmsnorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -1046,7 +1030,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -1055,7 +1039,7 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1064,7 +1048,7 @@ def rmsnorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. 
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 034e149c50..2911b5a420 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -93,7 +93,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -114,6 +114,10 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + scaling_mode, + QuantizeLayout( + q_layout + ), # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -176,7 +180,7 @@ def lowering( ctx, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, @@ -302,7 +306,7 @@ def infer_sharding_from_operands( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -322,9 +326,9 @@ def 
infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -374,7 +378,7 @@ def partition( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -394,9 +398,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -445,7 +449,7 @@ def sharded_impl(x, scale): is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -588,7 +592,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == 
ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -31,6 +31,9 @@ #include "transformer_engine/activation.h" #include "utils.h" +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax { @@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); + +pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, bool is_2x); + // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); @@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training); @@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); - -pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x); + DType in_dtype, DType out_dtype, + 
JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e71597e4b3..fc7f231f34 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -17,7 +17,7 @@ namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, + Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis @@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale 
!= nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("act_enum") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto dact_input_tensor = TensorWrapper(reinterpret_cast(&temp), 
dact_input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); - auto output_tensor = TensorWrapper(static_cast(scaling_mode)); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter @@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } } - if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, - bool is_dbias, int64_t act_enum) { + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = 
convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); - auto scaling_mode = static_cast(scaling_mode_enum); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? 
scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK( - !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { @@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") + .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias") - .Attr("act_enum"), + .Attr("is_dbias"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e5ec160c91..d4b9bf720e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, uint8_t *workspace_ptr, const size_t 
workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const int64_t &scaling_mode, + int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, cudaStream_t stream) { size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); @@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, - nullptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, - nullptr, reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); + rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); + + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); size_t sinv_k = k / MXFP8_BLOCK_SIZE; @@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh rhs_sinv_shape[1] = sinv_k; // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); - 
lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, lhs_sinv_shape); rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); lhs_ptr += m * k * lhs_dtype_bytes; @@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { + Result_Type workspace_flatten, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode) { // Inputs auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); @@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Ret() // out_flatten .Ret() // workspace_flatten .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8526e20c0..f7577c24f3 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -40,5 +40,28 @@ enum class QuantizeLayout { ROWWISE_COLWISE, }; +enum class JAXX_Scaling_Mode : int64_t { + NO_SCALING = 0, + DELAYED_TENSOR_SCALING = 1, + MXFP8_1D_SCALING = 2, +}; + +static 
NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { + switch (mode) { + case JAXX_Scaling_Mode::NO_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::MXFP8_1D_SCALING: + return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + break; + default: + NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 03855753cf..e23e42f528 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -14,7 +14,8 @@ namespace jax { pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - auto _scaling_mode = static_cast(scaling_mode); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated - if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { int temp = 1; 
output_tensor.set_columnwise_data(static_cast(&temp), out_dtype, input_shape); } @@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, @@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, int scaling_mode, bool is_2x) { + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); - auto _scaling_mode = static_cast(scaling_mode); auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto output_tensor = 
TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); if (is_fp8_dtype(out_dtype)) { @@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc scale_inv_buf->dimensions().back()}); } - if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); @@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, @@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("zero_centered_gamma") .Attr("epsilon") .Attr("sm_margin") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ebdfe461c7..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("RMSNorm", NVTE_Norm_Type::RMSNorm) .export_values(); - 
pybind11::enum_(m, "NVTE_Scaling_Mode", pybind11::module_local()) - .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) - .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) - .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) + .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index b48ee8a9b9..481dbd7cdf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -13,7 +13,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ int temp = 0; auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); - auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if 
(q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); + + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + TensorWrapper dummy_workspace; nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), @@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type output_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, - int64_t flatten_axis) { + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -54,7 +83,6 @@ Error_Type 
DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto scaling_mode = static_cast(scaling_mode_enum); auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); @@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T std::vector workspace_shape{workspace_dims.begin(), workspace_dims.end()}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? 
scale_inv_buf + : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis"), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a944848881..45ff8d7ed9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -361,7 +361,7 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b1e9ba03b4..d68eb3c6c2 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -84,8 +84,8 @@ def _dq_func_block_scaling(scaled_tensor): ) funcs = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, + ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } @staticmethod diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d144aa69d..98f280b9a9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ 
-94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) return (False, "Unsupported scaling_mode!") def is_fp8_available( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, ) -> Tuple[bool, str]: """Check if FP8 is available for the given scaling mode and GPU. @@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ValueError: If the recipe type is not supported """ if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.NVTE_DELAYED_TENSOR_SCALING + return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.NVTE_MXFP8_1D_SCALING + return ScalingMode.MXFP8_1D_SCALING raise ValueError("Invalid fp8_recipe!") @@ -217,7 +217,7 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False IF_QUANTIZE_2X: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING + SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 @@ -253,11 +253,11 @@ def finalize(cls) -> None: cls.MARGIN = 0.0 cls.FP8_FORMAT = recipe.Format.HYBRID cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING 
cls.IF_QUANTIZE_2X = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index bd7045453b..b57043a034 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -530,8 +530,8 @@ class QuantizerFactory: """ quantizer_type_map = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } @staticmethod @@ -556,8 +556,9 @@ def create( A single quantizer or tuple of quantizers """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted - # assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING - if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): + assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" + # import pdb; pdb.set_trace() + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] @@ -651,4 +652,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = 
QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 95bbc9bb41..34f63a994c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -19,6 +19,8 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp +from transformer_engine_jax import JAXX_Scaling_Mode + __all__ = ["ScalingMode"] @@ -216,25 +218,20 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) -# (Phuong: Map the NVTEScalingMode value to the ScalingMode - - @dataclass(frozen=True) @register_pytree_node_class class ScalingMode(Enum): """Enumeration of tensor scaling modes with their corresponding metadata implementations. This class defines the available scaling modes for tensor quantization: - - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_INVALID_SCALING: Invalid scaling mode - - NVTE_NO_SCALING: No scaling applied + - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales + - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - NO_SCALING: No scaling applied """ - NVTE_DELAYED_TENSOR_SCALING = 0 - NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 100 - NVTE_NO_SCALING = 1000 + NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. 
@@ -329,8 +326,8 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR - ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c34a235d94..0ef30f4728 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -236,13 +236,12 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # TODO(Phuong): Handle padding !? scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv - # TODO(Phuong): constaint padded scale_inv? 
return ScaledTensor1x( data=data, scale_inv=scale_inv, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6440c628cd..0d442435bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -616,7 +616,7 @@ def forward( rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -1564,7 +1564,7 @@ def backward(ctx, dout): rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..79d6391e79 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = [ "general_gemm", @@ -112,6 +113,10 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + if 
isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + # There is not use_split_accumulator == False + # implementation for Float8BlockwiseQTensorBase GEMM + use_split_accumulator = True args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..3b349e7f09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + + private: int block_scaling_dim = 2; public: diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 1ef6f5258d..bf037fe931 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + // sanity check, since activation fusion is not supported for blockwise quantization yet + // need to raise an error here instead of silently going into act_func with wrong numerics + 
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); } else { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c3ccff154..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob if (te_output.numel() == 0) return out; + QuantizationConfigWrapper quant_config; + quant_config.set_noop_tensor(te_noop.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; + // this config is used for cs scaling factor computation + // because compute scale is cannot be fused with quantize kernel + // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - 
nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index cbdeee5b48..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,6 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -166,15 +167,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + 
at::cuda::getCurrentCUDAStream()); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -293,6 +297,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -309,15 +314,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ac6292e53..fbf31a7f5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -257,12 +257,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso 
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires pow2 scales"); - NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires amax_epsilon==0"); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index a873586032..e12990f79c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -69,6 +69,7 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + // TODO: switch to nvte_quantize_v2 with advanced numerical options nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..7a1fde164b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -24,10 +24,11 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.float8_blockwise_tensor import 
Float8BlockQuantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -937,6 +938,74 @@ def _all_gather_fp8( return out, handle +def _all_gather_fp8_blockwise( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, # pylint: disable=unused-argument + quantizer: Optional[Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """ + All-gather FP8 tensor along first dimension for blockwise quantization. + + Returns: quantizer(gather(inp)) + + NOTE: The implementation is not sophisticated enough to honor async_op=True. + In some cases it falls back to synchronous gather and invokes the quantizer. + """ + + # Input tensor attributes + device: torch.device + dtype: torch.dtype + if isinstance(inp, torch.Tensor): + device = inp.device + dtype = inp.dtype + elif isinstance(inp, Float8BlockwiseQTensorBase): + if inp._rowwise_data is not None: + device = inp._rowwise_data.device + elif inp._columnwise_data is not None: + device = inp._columnwise_data.device + else: + raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. 
+ else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " + f"found {inp.__class__.__name__})" + ) + world_size = get_distributed_world_size(process_group) + + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") + + # Output tensor dims + if out_shape is None: + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # Doing BF16 gather for now as baseline because it's simpler + if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + out = quantizer(out) + return out, None + # Implementation of fp8 gather needs to account for: + # * Getting columnwise data as a transpose of how it is stored for GEMMS. + # * Gathering non GEMM swizzled scales. + # * Refer to scaffold code when implementing at: + # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 + raise NotImplementedError("fp8 blockwise allgather not yet implemented") + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1075,7 +1144,9 @@ def gather_along_first_dim( async_op: bool = False, quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: - """All-gather tensors and concatenate along first dimension.""" + """ + All-gather tensors and concatenate along first dimension. 
+ """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) @@ -1100,6 +1171,16 @@ def gather_along_first_dim( out_shape=out_shape, ) + # FP8 block scaling case, block length = 128 + if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + return _all_gather_fp8_blockwise( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # MXFP8 case if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 38f829c079..c02ff73391 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -19,6 +20,7 @@ Format, MXFP8BlockScaling, Float8CurrentScaling, + Float8BlockScaling, ) from .constants import dist_group_type @@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ): + return True, "" + return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
+ + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -109,6 +122,8 @@ class FP8GlobalStateManager: skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None @classmethod def reset(cls) -> None: @@ -134,6 +149,8 @@ def reset(cls) -> None: cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -161,6 +178,15 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + if cls.fp8_block_scaling_available is None: + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( + check_fp8_block_scaling_support() + ) + return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -434,6 +460,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -786,8 +815,10 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState + elif 
recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState else: - raise ValueError("{recipe.__class__.__name__} is not supported") + raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, @@ -928,3 +959,108 @@ def make_quantizers(self) -> list: from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. 
+ assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..65f47a0817 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,8 +35,10 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -43,6 +46,7 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _multi_stream_cublas_workspace = [] +_dummy_wgrads = {} _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 @@ -78,6 +82,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: return _multi_stream_cublas_workspace +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + + def initialize_ub( shape: list, tp_size: int, @@ -499,6 +519,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -841,7 +865,13 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ): grad_output = quantizer(grad_output) @@ -859,11 +889,21 @@ def grad_output_preprocess( # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(quantizer, Float8BlockQuantizer): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
+ grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..1ea66a7f2c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -91,6 +91,8 @@ def forward( # TODO Support Float8 Current Scaling # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..df3ae05f31 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -19,6 +19,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -56,9 +57,11 @@ restore_from_saved, ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + from ..cpp_extensions import ( general_gemm, ) @@ -137,11 +140,6 @@ def forward( ln_bias = 
cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - # Avoid quantized norm kernel if norm output will be returned - with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered - ) - tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output @@ -174,6 +172,18 @@ def forward( columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. + force_hp_blockwise_ln_out_gather = ( + fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + with_quantized_norm = ( + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not force_hp_blockwise_ln_out_gather + ) + # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( @@ -210,7 +220,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -316,6 +326,7 @@ def forward( ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: @@ -326,6 +337,10 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # For force_hp_blockwise_ln_out_gather, we should + # be saving the unquantized ln_out to ctx. 
+ assert not force_hp_blockwise_ln_out_gather + # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) @@ -604,11 +619,14 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + # async_op is not compatible with high precision gather since + # gather_along_first_dim does not offer callback chaining. + gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -689,6 +707,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather may have been done in BF16 + # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -796,18 +821,15 @@ def backward( if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..e51fe43cc0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -52,7 +52,6 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) - from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -62,6 +61,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..tensor.quantized_tensor import ( @@ -104,17 +104,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } # no activation fusion written yet - # Per-tensor current scaling: [] - return { - "gelu": (tex.gelu, 
tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), - "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, None), - "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, None), - } + # Per-tensor current scaling or fp8 blockwise scaling: [] + if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), + } + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -122,7 +124,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] - # Per-tensor current scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling: [] funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -214,12 +216,20 @@ def forward( with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered ) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + # Kernels not available for norm fusion. 
+ with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe + force_hp_fc1_input_gather = ( + fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + # Configure quantizer for norm output if fp8: if fc1_input_quantizer is None: @@ -261,12 +271,13 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8: - ln_out = fc1_input_quantizer(ln_out) + if not force_hp_fc1_input_gather: + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -282,7 +293,10 @@ def forward( quantizer=(fc1_input_quantizer if fp8 else None), ) else: - if fp8 and not with_quantized_norm: + # NOTE: force_hp_fc1_input_gather is redundant with else, but + # here for clarity. We should not quantize ln_out if bwd needs + # to gather in hp. + if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out @@ -336,6 +350,7 @@ def forward( # - bias_gelu_fusion - only for full precision. # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer if activation != "gelu": + # blockwise scaled gemms don't support gemm_gelu_fusion in fwd. 
gemm_gelu_fusion = bias_gelu_fusion = False else: if fp8: @@ -376,7 +391,12 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs else: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise. + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -462,6 +482,8 @@ def forward( if not return_layernorm_output: clear_tensor_data(ln_out) ln_out = None + elif force_hp_fc1_input_gather: + assert not isinstance(ln_out, QuantizedTensor) if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None @@ -490,6 +512,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -505,6 +528,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -696,11 +720,12 @@ def backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) + gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) else: ln_out_total = ln_out @@ -712,12 +737,13 @@ def backward( ) else: 
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - # There are 5 possible fusion paths + # There are 6 possible fusion paths # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) ) @@ -753,6 +779,9 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + grad_arg = True + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + grad_arg = False fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( act_out, grad_output, @@ -764,14 +793,18 @@ def backward( ), quantization_params=None, # wgrad in high precision layout="NT", - grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + grad=grad_arg, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None: + # BGRAD not fused with GEMM for float8 blockwise gemm. 
+ fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ clear_tensor_data(act_out) # bias computation @@ -808,7 +841,14 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO float8 blockwise current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -904,6 +944,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -1556,7 +1603,8 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + rowwise=True, + columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..2887b2e452 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -16,6 +16,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -59,9 +60,10 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + __all__ = ["Linear"] @@ -129,6 +131,10 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. + force_hp_input_gather = ( + fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. 
if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -142,19 +148,27 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - if not isinstance(inputmat, QuantizedTensor): - columnwise_usage = backward_needs_input and isinstance( - input_quantizer, MXFP8Quantizer + if force_hp_input_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, tp_group, quantizer=input_quantizer + ) + else: + if not isinstance(inputmat, QuantizedTensor): + columnwise_usage = backward_needs_input and isinstance( + input_quantizer, MXFP8Quantizer + ) + # force_hp_input_gather should enforce this + assert not isinstance(input_quantizer, Float8BlockQuantizer) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, ) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - inputmat = input_quantizer(inputmat) - own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=input_quantizer, - ) else: if ( FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() @@ -276,6 +290,8 @@ def forward( # can be allgathered. if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if force_hp_input_gather: + assert not isinstance(inputmat, QuantizedTensor) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. 
@@ -322,8 +338,9 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -519,11 +536,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -609,6 +627,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. 
+ ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): @@ -688,18 +713,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb93eb5e6b..f4d4254537 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,6 +23,7 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext @@ -483,6 +484,12 @@ def _functional_forward( "Attempting to generate MXFP8 output tensor, " "but GEMM with MXFP8 output is not supported" ) + if isinstance(output_quantizer, Float8BlockQuantizer): + raise RuntimeError( + "Attempting to generate Float8BlockQuantized output tensor, " + "but GEMM with Float8BlockQuantized output is not supported" + ) + if output_quantizer is not None: 
output_quantizer.set_usage(rowwise=True, columnwise=False) @@ -523,7 +530,7 @@ def _functional_forward( # Configure input tensor for backward pass if own_quantized_x_local: - x_local.update_usage(rowwise_usage=False) + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save @@ -679,7 +686,9 @@ def _functional_backward( quantizer=input_quantizer, ) else: - if not isinstance(x_local, QuantizedTensor): + if isinstance(x_local, QuantizedTensor): + x_local.update_usage(columnwise_usage=True) + else: x_local = input_quantizer(x_local) x = x_local else: @@ -706,15 +715,19 @@ def _functional_backward( raise ValueError("Weight tensor is required to compute input grad") w = weight w_is_quantized = isinstance(w, QuantizedTensor) - if with_quantized_compute and not w_is_quantized: - if weight_quantizer is None: - raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(columnwise=True) - w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) + if with_quantized_compute: + if w_is_quantized: + w.update_usage(columnwise_usage=True) + else: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + else: + if w_is_quantized: + w = w.dequantize(dtype=dtype) + elif w.dtype != dtype: + w = w.to(dtype=dtype) # Synchronize tensor-parallel communication _wait_async(dy_async) @@ -867,8 +880,8 @@ def op_forward( # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
- input_quantizer.set_usage(columnwise=weight_requires_grad) - weight_quantizer.set_usage(columnwise=False) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) # Get autocast dtype if needed dtype = None diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..802f4c25e3 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -17,8 +17,10 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, + fp8_autocast, ) from ..tensor import Quantizer @@ -218,6 +220,11 @@ def _reset_quantization_recipe_state( if num_quantizers == 0: continue + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) + # Construct quantization recipe state recipe_state = RecipeState.create( recipe, @@ -259,8 +266,13 @@ def _update_quantization_recipe_state( continue recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( - recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) + or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + or ( + recipe.float8_block_scaling() + and not isinstance(recipe_state, Float8BlockScalingRecipeState) + ) + ) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return @@ -508,7 +520,7 @@ def forward( def get_extra_state(self) -> torch.Tensor: """Serialize extra state - Contains metadata for FP8 casting. + Contains metadata for quantization recipe. 
""" @@ -540,23 +552,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: dst.copy_(src, non_blocking=True) return dst - # Store FP8 state + # Store quantizer state if needed state = {} for mode in ("forward", "backward"): - # Get state for a given FP8 tensor - if self.num_quantizers(mode) == 0: + # Skip if op has no quantizer state + if self._fp8_metas is None or self._fp8_metas[mode] is None: continue - fp8_meta = self.get_fp8_meta(mode) + + # Quantizer state + fp8_meta = self._fp8_metas[mode] state[mode] = {} + state[mode]["recipe"] = fp8_meta["recipe"] - # Store tensors - if "scaling_fwd" in fp8_meta: - state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) - if "scaling_bwd" in fp8_meta: - state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + # Copy tensors to CPU and store + if state[mode]["recipe"].delayed(): + if mode == "forward": + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items extra = {} @@ -595,37 +611,37 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) dst.copy_(src, non_blocking=True) - # Load FP8 state + # Load quantizer state if needed for mode in ("forward", "backward"): - # Get state for a given FP8 tensor + # Skip if checkpoint has no quantizer state if mode not in state: continue - if self.num_quantizers(mode) == 0: - continue - fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue - # Load extra state + # Get op's quantizer state, initializing if needed + if 
self._fp8_metas is None or self._fp8_metas[mode] is None: + with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + self._reset_quantization_recipe_state() + fp8_meta = self._fp8_metas[mode] + + # Load extra items + fp8_meta["recipe"] = state[mode]["recipe"] fp8_meta.update(state[mode]["extra_fp8_variables"]) - if "amax_history_fwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) - elif "amax_history_bwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Load tensors - fp8_meta = self.get_fp8_meta(mode) - if "scaling_fwd" in fp8_meta: - fp8_meta_fwd = fp8_meta["scaling_fwd"] - copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) - if "scaling_bwd" in fp8_meta: - fp8_meta_bwd = fp8_meta["scaling_bwd"] - copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + if state[mode]["recipe"].delayed(): + if mode == "forward": + copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) + copy_tensor( + state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history + ) + if mode == "backward": + copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) + copy_tensor( + state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history + ) # Finish CPU-GPU memory transfers torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 9135237854..7dc380606d 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -36,8 +36,8 @@ class 
Float8BlockwiseQTensorBase: def __new__( cls, *args, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, @@ -71,10 +71,16 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward""" + """ + Prepare the tensor base for saving for backward + + This does not clear the tensors currently, because with PP config + that clears the weight cache between micro-batches. If the rowwise + data is not required for backward, this is a possible memory + pessimization, but is consistent with the other quantized tensor + classes. + """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 138d1fd29e..695c5ffb8c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -44,7 +44,6 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - assert rowwise self.dtype = fp8_dtype self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales @@ -168,6 +167,11 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.append(shape[i]) return tuple(colwise_shape) + # TODO(kwyss): With FP8 gather support, we need to implement a + # shape/layout/swizzle check to know whether FP8 gather works + # cleanly by stacking data without aliasing tiles and whether + # the scales also stack on the proper dimensions. 
+ def make_empty( self, shape: Iterable[int], @@ -181,13 +185,16 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -489,7 +496,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst.dtype = src.dtype # Check that tensor dimensions match if ( diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 843c7936f2..2694319a0f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -347,6 +347,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + shape: torch.shape, ) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -361,10 +362,11 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling""" return ( MXFP8Tensor._make_in_reduce_ex, ( @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self.shape, ), )