From 3b4366be34ec8e5b96c73ba20ec22f0dfac1b97a Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 3 Sep 2025 19:27:04 +1200 Subject: [PATCH 01/78] Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../pytorch/comm_gemm_overlap/te_layer_with_overlap.py | 6 +++++- tests/pytorch/distributed/run_layer_with_overlap.py | 8 ++++++-- .../distributed/test_fusible_ops_with_userbuffers.py | 4 ++-- transformer_engine/pytorch/module/base.py | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index eeb79c235c..d52e97d65c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -264,7 +264,11 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) [batched_size, hidden_size], tp_size, quantization_modes=[ - UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 1dabf6e451..2a6e55b2c0 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -420,10 +420,14 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): } quantization_modes = [ - UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) ] if opts.first_last_layers_bf16 and opts.fp8: - 
quantization_modes.append(UserBufferQuantizationMode.NONE) + quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 17d3512923..d6ddfe27c9 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -508,9 +508,9 @@ def main() -> None: torch.distributed.get_world_size(group), quantization_modes=[ ( - UserBufferQuantizationMode.FP8 + te.module.base.UserBufferQuantizationMode.FP8 if model_config.quantization is not None - else UserBufferQuantizationMode.NONE + else te.module.base.UserBufferQuantizationMode.NONE ) ], dtype=model_config.dtype, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3bbfaacdf5..a6275abd19 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -473,7 +473,7 @@ def add_ub( fp8_buf = (name in layers_all_gather_overlap) or ( user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] ) - ub_cfg.update(ub_cfgs[name]) + ub_cfg.update(user_ub_cfg[name]) ub_cfg["fp8_buf"] = fp8_buf add_ub(name, quantization_mode, **ub_cfg) From f378eaf2899f1148c68b567f62595e940556da7f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:15:52 -0700 Subject: [PATCH 02/78] [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass 
for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 9 +++++++++ transformer_engine/jax/cpp_extensions/attention.py | 6 ++++++ transformer_engine/jax/cpp_extensions/misc.py | 10 ++++++++++ 3 files changed, 25 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ec530a3959..87dfc113c7 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -41,6 +41,7 @@ from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, + get_device_compute_capability, ) from distributed_test_base import assert_equal_collectives @@ -348,6 +349,14 @@ def _check_configs(self): "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + if ( + get_device_compute_capability(0) == 100 + and self.dropout_prob == 0.1 + and self.attn_bias_type is not AttnBiasType.NO_BIAS + ): + pytest.skip( + "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): diff --git 
a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 089ef75f1c..df89174b2c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -34,6 +34,7 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + get_all_device_compute_capability, ) from ..sharding import ( global_mesh_resource, @@ -2745,6 +2746,11 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) + if 100 in get_all_device_compute_capability(): + assert not ( + attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 + ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 94dfaa45a4..3bda37128b 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -193,6 +193,16 @@ def get_min_device_compute_capability(): ) +def get_all_device_compute_capability(): + """ + Returns a list of compute capability of all local devices. 
+ """ + return tuple( + transformer_engine_jax.get_device_compute_capability(local_gpu_id) + for local_gpu_id in range(len(jax.local_devices())) + ) + + def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): """ Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to From 0f68f7b2f9e6e94d7037513942389432f9e58d68 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:11:33 -0700 Subject: [PATCH 03/78] [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../include/transformer_engine/recipe.h | 15 +++++ .../common/recipe/current_scaling.cu | 66 ++++++++++++++++--- .../common/transpose/cast_transpose.h | 5 +- .../quantize_transpose_square_blockwise.cu | 20 ++++-- .../quantize_transpose_vector_blockwise.cu | 18 +++-- .../common/util/cast_kernels.cuh | 11 ++-- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +- 7 files changed, 110 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 50fb696ea6..2fc8c1095c 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( */ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! 
\brief Compute an FP8 tensor's amax with quantization config. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor, using the provided + * quantization configuration. + * One useful config is the noop tensor, which is needed by cuda graph. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig config, cudaStream_t stream); + /*! \brief Update an FP8 tensor's scale based on its amax. * * This is only supported for FP8 tensors with per-tensor scaling. diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index e1657b77a1..fd907efcba 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512; template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, const size_t N, - const size_t num_aligned_elements) { + const size_t num_aligned_elements, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + VectorizedLoader loader(input, N); InputType max = 0.f; const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__ } template -void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr, + cudaStream_t stream) { // Zero out amax so we can update with atomic max NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), 
stream)); @@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud switch (align) { case Alignment::SAME_ALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::SAME_UNALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::DIFFERENT: { // This case is a logic error, since there is only one pointer (input) // in the alignment check. Still safe to process without vectorization. - amax_kernel<1, true, InputType><<>>(input, amax, N, N); + amax_kernel<1, true, InputType> + <<>>(input, amax, N, N, noop_ptr); break; } } @@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud } // namespace } // namespace transformer_engine -void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { - NVTE_API_CALL(nvte_compute_amax); +namespace { + +void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream, + const NVTEQuantizationConfig config_) { using namespace transformer_engine; // Check input tensor @@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt to_string(output.amax.dtype), ")"); CheckOutputTensor(output, "output_compute_amax", true); + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel(reinterpret_cast(input.data.dptr), reinterpret_cast(output.amax.dptr), input.data.numel(), - stream);); // NOLINT(*) + noop_ptr, stream);); // NOLINT(*) +} + +} // anonymous namespace + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + compute_amax_impl(input_, output_, stream, nullptr); +} + +void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_config); + compute_amax_impl(input_, output_, stream, config_); } namespace transformer_engine { @@ -151,7 +182,11 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, - const float epsilon) { + const float epsilon, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); } @@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, max_fp8 = Quantized_Limits::max_norm;); + // noop tensor for cuda graph + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Update scale compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(output.amax.dptr), reinterpret_cast(output.scale.dptr), max_fp8, config.force_pow_2_scales, - config.amax_epsilon); + config.amax_epsilon, noop_ptr); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a737239260..abfa226e88 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + const SimpleTensor &noop_tensor, cudaStream_t stream); // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { @@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scale, cudaStream_t stream); + const bool pow_2_scale, const SimpleTensor &noop_tensor, + cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index a603d1f1a2..c3f085b877 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const __grid_constant__ CUtensorMap 
tensor_map_output_t, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream) { + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); @@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + if (return_transpose) { NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); @@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const 
SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - tensor_map_output_trans, pow_2_scale); + tensor_map_output_trans, pow_2_scale, noop_ptr); } else { block_scaled_cast_transpose_kernel_notaligned @@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + pow_2_scale, noop_ptr); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 6f5c0f3a6c..4c82b8c81b 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scaling) { + const bool pow_2_scaling, const float* noop_ptr) { + // skip execution if noop + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, 
FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow2_scale, cudaStream_t stream) { + const bool pow2_scale, const SimpleTensor& noop_tensor, + cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; @@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -613,9 +621,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, - columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + columnwise_option, pow2_scale, noop_ptr);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1158132e3f..8d87351181 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor.data, stream); break; } case NVTE_BLOCK_SCALING_1D: { @@ -1455,10 +1456,10 @@ 
void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; } - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, stream); + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor.data, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0c75789ed9..c690cd522a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -518,7 +518,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { - NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); } // Perform amax reduction if needed From e9a5fa4e368464f3b310b90ab7f670f35319344b Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 4 Sep 2025 22:39:53 +0200 Subject: [PATCH 04/78] =?UTF-8?q?[PyTorch]=C2=A0fix=20cross=20entropy=20va?= =?UTF-8?q?nishing=20gradients=20(#2139)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: 
Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon --- tests/pytorch/test_parallel_cross_entropy.py | 59 ++++++++++++------- .../pytorch/triton/cross_entropy.py | 3 + 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 77bea2b360..fa56852ffc 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -6,6 +6,8 @@ import torch from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from utils import dtype_tols + class TestParallelCrossEntropy: @@ -18,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float): label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" ) - def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): - + def generate_input( + self, + dtype: torch.dtype, + swap_dim: bool, + ignore_idx: bool, + device: torch.device = "cuda", + ): SQ = random.choice([64, 128]) batch = random.choice([1, 2]) vocab = random.choice([64000, 128000]) ignore = random.sample(range(0, SQ - 1), 5) + # Generate random data if swap_dim: - self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + 
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device) else: - self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device) if ignore_idx: for i in ignore: @@ -40,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): else: self.tar_test[0][i] = -100 + # Make copy of data for reference implementation self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + # Enable autograd + self.input_test.requires_grad_() + self.input_ref.requires_grad_() + def one_iteration_test( self, dtype: torch.dtype, @@ -52,18 +65,20 @@ def one_iteration_test( ignore_idx: bool = False, ): + # Random data self.generate_input(dtype, swap_dim, ignore_idx) - self.input_test.requires_grad_(True) - self.input_ref.requires_grad_(True) - + # Forward pass test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) - # Handle backward pass based on the test scenario + # Compute square to avoid trivial backward pass + test_loss = torch.square(test_loss) + ref_loss = torch.square(ref_loss) + + # Backward pass if reduce_loss: test_loss.backward() ref_loss.backward() @@ -71,16 +86,18 @@ def one_iteration_test( test_loss.sum().backward() ref_loss.sum().backward() - test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - - if ignore_idx: - print(test_loss, ref_loss) - - # Compare gradients when backward pass was called - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) - + # 
Check that loss and grad input match + tols = dtype_tols(dtype) + test_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = ref_loss.reshape(test_loss.size()) + test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = ref_grad_input.reshape(test_grad_input.size()) + torch.testing.assert_close(test_loss, ref_loss, **tols) + torch.testing.assert_close(test_grad_input, ref_grad_input, **tols) + + # Reset data self.input_test = None self.input_ref = None self.tar_test = None diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 323a939223..7cfff1da9d 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -230,6 +230,7 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -252,6 +253,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -360,6 +362,7 @@ def cross_entropy_backward( _input, _input.stride(-2), grad_output, + 1 if grad_output.numel() > 1 else 0, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From 11e9d669ae827b13dce309fefa79f8938da34352 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 5 Sep 2025 11:24:22 +0800 Subject: [PATCH 05/78] Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com 
hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 7 ++----- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6275abd19..0f2e3c4de1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1482,8 +1482,7 @@ def backward_dw(self): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: weight_tensor = noop_cat(self._get_weight_tensors()) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) if bias_tensor.grad is None: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efaca..e9189ccc59 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -452,9 +452,6 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() @@ -829,8 +826,7 @@ def backward_dw(self): bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): - if weight_params[i].grad is None: - weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) + 
weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: for i in range(self.num_gemms): if bias_params[i].grad is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 182bf99f86..a6c55ceb79 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1197,7 +1197,6 @@ def fc1_wgrad_gemm( "with Userbuffers (tensor-parallel communication overlapping)" ) ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm) - fc1_wgrad = None if fuse_gemm_and_bias_fc1_wgrad: fc1_bias_grad = None else: @@ -2168,10 +2167,8 @@ def backward_dw(self): if self.fc1_bias.grad is None: self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) if not self.fuse_wgrad_accumulation: - if self.fc2_weight.grad is None: - self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) - if self.fc1_weight.grad is None: - self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) del fc2_bias_grad_ del fc2_wgrad del fc1_wgrad From b10f436aa28ec8d885eb0d9bf134c320ffb6353d Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Thu, 4 Sep 2025 22:09:55 -0700 Subject: [PATCH 06/78] Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- setup.py | 7 ++++--- tests/cpp/CMakeLists.txt | 1 + transformer_engine/common/__init__.py | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 52adaf9238..ed1f5b8a9d 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ from 
build_tools.te_version import te_version from build_tools.utils import ( cuda_archs, + cuda_version, get_frameworks, remove_dups, ) @@ -70,11 +71,11 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( - "nvidia-cublasmp-cu12" - ).locate_file("nvidia/cublasmp/cu12") + f"nvidia-cublasmp-cu{cuda_version()[0]}" + ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - "nvidia-nvshmem-cu12" + f"nvidia-nvshmem-cu{cuda_version()[0]}" ).locate_file("nvidia/nvshmem") cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") print("CMAKE_FLAGS:", cmake_flags[-2:]) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c2c9d0d915..412c5d34d9 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,5 +43,6 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 7feb5fda5f..dd1ec480b2 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str: except ModuleNotFoundError: return "" + # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" + # above doesn't through. However, they don't set "__file__" attribute. 
+ if nvidia.__file__ is None: + return "" + include_dir = Path(nvidia.__file__).parent / "cuda_runtime" return str(include_dir) if include_dir.exists() else "" From c47f329b2084406093124851a3aeecb935183def Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:56:15 -0700 Subject: [PATCH 07/78] [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 16 ++- tests/jax/test_distributed_layernorm_mlp.py | 17 ++- transformer_engine/jax/activation.py | 12 +- .../jax/cpp_extensions/activation.py | 49 +++---- transformer_engine/jax/cpp_extensions/gemm.py | 16 ++- .../jax/cpp_extensions/normalization.py | 34 ++--- .../jax/cpp_extensions/quantization.py | 70 +++++----- transformer_engine/jax/dense.py | 37 +++--- transformer_engine/jax/layernorm.py | 5 +- transformer_engine/jax/layernorm_dense.py | 12 +- transformer_engine/jax/layernorm_mlp.py | 23 +++- transformer_engine/jax/quantize/quantizer.py | 42 ++++-- .../jax/quantize/scaling_modes.py | 86 +++++++++++- transformer_engine/jax/quantize/tensor.py | 
124 +++++++++++++----- 14 files changed, 359 insertions(+), 184 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d5f21651db..11f07d9133 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( + NoScaleTensor, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -182,7 +183,7 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type) + return _jax_act_lu(x, activation_type).data def value_n_grad_ref_func(self, x, activation_type): jitted_reference = jit( @@ -337,8 +338,8 @@ def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantize ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - # if isinstance(ln_out, ScaledTensor): - # ln_out = ln_out.dequantize() + # This is a no-op for non-quantized data + ln_out = ln_out.dequantize() return ln_out key = jax.random.PRNGKey(0) @@ -765,7 +766,9 @@ def _test_quantize_dact_dbias( te_output, jax_output, precise_comparison=precise_comparison ) else: - assert_allclose(te_output, jax_output) + assert isinstance(te_output, NoScaleTensor) + assert isinstance(jax_output, NoScaleTensor) + assert_allclose(te_output.data, jax_output.data) if is_dbias: # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. 
@@ -1020,8 +1023,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - if isinstance(ln_out, ScaledTensor): - ln_out = ln_out.dequantize() + ln_out = ln_out.dequantize() return ln_out @@ -1177,7 +1179,7 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) + x = _jax_act_lu(linear_1_out, activation_type).data linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 90b762c240..a44921c641 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -173,7 +173,9 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + ): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -184,7 +186,7 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -226,7 +228,12 @@ def _test_layernorm_mlp_grad( 
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): @@ -252,7 +259,7 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -281,7 +288,7 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index ef6def2d03..12b35ec43c 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -14,7 +14,7 @@ from . 
import cpp_extensions as tex -from .quantize.tensor import ScaledTensor +from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer @@ -22,7 +22,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, -) -> Union[jnp.ndarray, ScaledTensor]: +) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. This function applies a sequence of activation functions to the input tensor. @@ -72,8 +72,8 @@ def _activation_fwd_rule(x, activation_type, quantizer): Tuple of (output, context) for backward pass """ fwd_output = tex.act_lu(x, activation_type, quantizer) - if isinstance(fwd_output, ScaledTensor): - fwd_output = fwd_output.dequantize() + # This is a no-op for higher-precision tensors + fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) @@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) + # No quantization is used in this VJP backward, so the output should + # always be a NoScaleTensor + assert isinstance(dx, NoScaleTensor) + dx = dx.data return (dx, None) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index fe2253598f..d3c7d2b086 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -29,7 +29,7 @@ ) from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation 
quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: +def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S x = jnp.squeeze(x, axis=-2) if quantizer: return quantizer.quantize(x, flatten_axis=-1) - return x + return NoScaleTensor(data=x, amax=None) def _jax_quantize_dact_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, @@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias( _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) - (dx,) = vjp_func(dz.astype(jnp.float32)) + # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. + dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) + (dx,) = vjp_func(dz) dbias = None if is_dbias: @@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias( dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) + dx = NoScaleTensor(data=dx, amax=None) return dx, dbias @@ -981,7 +984,6 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -990,7 +992,6 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. 
quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1035,10 +1036,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) - if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype - ) + out = NoScaleTensor( + data=out, + amax=None, + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1092,7 +1093,6 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1103,7 +1103,6 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1146,19 +1145,10 @@ def quantize_dact_dbias( if is_dbias: dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - if noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - dbias, - ) - + output = NoScaleTensor( + data=output, + amax=None, + ) return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet @@ -1167,7 +1157,7 @@ def quantize_dact_dbias( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None ) return _quantize_dbias_impl( - out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 ) is_gated = act_len == 2 @@ -1194,7 +1184,7 @@ def quantize_dact_dbias( quantizer=None, ) out, dbias = _quantize_dbias_impl( - out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 ) return out, dbias @@ -1258,7 +1248,6 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1268,7 +1257,6 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. 
@@ -1279,6 +1267,5 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index be73f708e2..acc8d67274 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -22,6 +22,8 @@ from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize from ..quantize import ( + AbstractBaseTensor, + NoScaleTensor, ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, @@ -228,6 +230,11 @@ def _dims_are_consecutive(dims): "require non-transposed LHS and transposed RHS operands " "(`contracting_dims=((-1, ), (-1, ))`)." ) + else: + assert lhs.dtype == rhs.dtype, ( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype assert ( @@ -1134,8 +1141,8 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( - lhs: Union[jnp.ndarray, ScaledTensor], - rhs: Union[jnp.ndarray, ScaledTensor], + lhs: Union[jnp.ndarray, AbstractBaseTensor], + rhs: Union[jnp.ndarray, AbstractBaseTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, @@ -1191,6 +1198,11 @@ def gemm( compute the GeLU contribution to the gradient. Only supported with TE's custom call to cuBLAS GEMM. 
""" + if isinstance(lhs, NoScaleTensor): + lhs = lhs.data + if isinstance(rhs, NoScaleTensor): + rhs = rhs.data + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility if lhs_quantizer is None or rhs_quantizer is None: quantizer_set = kwargs.get("quantizer_set", None) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 7296afc725..de1877de5c 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -30,7 +30,7 @@ ) from .quantization import _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -845,6 +845,7 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1) @@ -869,6 +870,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(rsigma, axis=-1) @@ -930,7 +932,7 @@ def layernorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, mu, rsigma + return NoScaleTensor(data=output, amax=None), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING @@ -1064,7 +1066,7 @@ def layernorm_bwd( ) mu_empty = jnp.zeros(mu.shape, mu.dtype) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, mu_empty, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, 
amax=None), mu_empty, rsigma_empty)) return NormBwdPrimitive.outer_primitive.bind( dz, x, @@ -1133,14 +1135,14 @@ def rmsnorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, rsigma + return NoScaleTensor(data=output, amax=None), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out, quantizer) + out, _ = _quantize_dbias_impl(out.data, quantizer) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1152,7 +1154,9 @@ def rmsnorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + ) return out, rsigma is_2x2x = quantizer.is_2x2x() @@ -1254,7 +1258,7 @@ def rmsnorm_bwd( gamma, ) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty)) mu = jnp.empty(()) dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind( dz, @@ -1276,7 +1280,6 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], - noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1293,7 +1296,6 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: A tuple containing: @@ -1321,20 +1323,6 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") - if quantizer is None and noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - mu, - rsigma, - ) - return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 198beb55eb..1813734b5e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -4,7 +4,7 @@ """JAX/TE custom ops for quantization""" import operator from functools import reduce -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import math from packaging import version @@ -38,6 +38,7 @@ QuantizeLayout, ScalingMode, compute_scale_from_amax, + NoScaleTensor, ) if version.parse(jax.__version__) >= version.parse("0.5.0"): @@ -64,7 +65,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): 7, 8, 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -535,11 +536,15 @@ def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): +def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1): + if isinstance(dx, NoScaleTensor): + dx = dx.data sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis assert sum_axis < dx.ndim, "Flatten axis 
out of bounds!" dtype = dtype or dx.dtype @@ -558,7 +563,9 @@ def _jax_quantize_dbias( flatten_axis: int = -1, ): if quantizer is None: - return x, None + if isinstance(x, NoScaleTensor): + return x, None + return NoScaleTensor(data=x, amax=None), None return ( quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), @@ -566,12 +573,11 @@ def _jax_quantize_dbias( def _quantize_dbias_impl( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -581,28 +587,15 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + # Early-exit for non-quantized call - dq_dtype = dq_dtype or x.dtype + dq_dtype = dq_dtype or x.data.dtype if quantizer is None: dbias = None if is_dbias: - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - if noop_scaled_tensor: - # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() - # always works. 
- return ( - ScaledTensorFactory.create_2x( - x, - None, - x, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=x.dtype, - data_layout="NN", - flatten_axis=flatten_axis, - ), - dbias, - ) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, @@ -630,21 +623,25 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias scale = jnp.empty((), jnp.float32) + amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). - amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) + amax = x.amax + if amax is None: + amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # Make sure amax is init with zero - amax = jnp.zeros((1,), jnp.float32) + if amax is None: + amax = jnp.zeros((1,), jnp.float32) # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) @@ -665,7 +662,7 @@ def _quantize_dbias_impl( updated_amax, dbias, ) = PrimitiveClass.outer_primitive.bind( - x, + x.data, scale, amax, out_dtype=quantizer.q_dtype, @@ -706,10 +703,9 @@ def _quantize_dbias_impl( def quantize( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: 
"""Quantize input tensor according to the quantizer. @@ -719,7 +715,6 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer is None. Returns: @@ -729,17 +724,15 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) return out def quantize_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -750,8 +743,6 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when - quantizer is None. 
Returns: A tuple containing: @@ -765,7 +756,6 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) @@ -968,7 +958,9 @@ def grouped_quantize( """ if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 65d65e7d4a..b0ba734e5f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -24,6 +24,7 @@ with_sharding_constraint_by_logical_axes, is_fp8_gemm_with_all_layouts_supported, TensorUsage, + get_quantize_config, ) @@ -80,23 +81,19 @@ def dense( Returns: Transformed output tensor """ - # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims=contracting_dims) - if bias is not None: - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - else: - output = _dense( - x, - kernel, - bias, - contracting_dims, - input_axes, - kernel_axes, - quantizer_set, - ) + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + + output = _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, + ) return output @@ -175,7 +172,9 @@ def _dense_fwd_rule( flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) casted_x = tex.quantize( - x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + x, + flatten_axis=flatten_axis_x, + quantizer=quantizer_set.x, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -183,7 +182,6 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, 
quantizer=quantizer_set.kernel, - noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -240,7 +238,6 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # GEMM NT diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 7a3ad597bf..0f5c6aeef6 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -17,7 +17,6 @@ from . import cpp_extensions as tex from .quantize import ( - ScaledTensor, Quantizer, ) @@ -112,8 +111,8 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps output, mu, rsigma = tex.normalization_fwd( x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer ) - if isinstance(output, ScaledTensor): - output = output.dequantize() + # This is a no-op for higher-precision tensors + output = output.dequantize() return output, (x, mu, rsigma, gamma, beta, quantizer) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index b830cdb4ff..fb97830759 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -22,6 +22,7 @@ noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, + get_quantize_config, ) @@ -68,6 +69,11 @@ def layernorm_dense( - The function supports automatic differentiation through JAX's custom VJP - Quantization is applied to both the normalized input and kernel """ + + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + output = _layernorm_dense( x, kernel, @@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) 
flatten_axis = 1 - len(kernel.shape) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + kernel, + flatten_axis=flatten_axis, + quantizer=quantizer_set.kernel, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 00e3ddc3e8..fc957801af 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -27,6 +27,7 @@ QuantizerSet, noop_quantizer_set, TensorUsage, + get_quantize_config, ) @@ -104,6 +105,11 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel_1 = kernel_1.astype(input_dtype) + kernel_2 = kernel_2.astype(input_dtype) + output = _layernorm_mlp( x, gamma, @@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, ) # NN GEMM @@ -300,13 +307,16 @@ def _layernorm_mlp_fwd_rule( # (batch..., hidden_in) -> (batch..., hidden) casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + dot_1_output, + activation_type, + quantizer=ffn2_quantizer_set.x, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, 
dot_2_input_axes) casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + kernel_2, + quantizer=ffn2_quantizer_set.kernel, ) # NN GEMM @@ -404,7 +414,9 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True + grad, + is_dbias=use_bias_2, + quantizer=ffn1_quantizer_set.dgrad, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -445,7 +457,6 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 6cecfa361f..306603bbe1 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,7 +19,13 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode -from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory +from .tensor import ( + ScaledTensor, + ScaledTensor1x, + ScaledTensor2x, + ScaledTensorFactory, + NoScaleTensor, +) from .helper import ( get_quantize_config, get_quantize_config_class, @@ -217,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer): data_layout: str = "NT" def _quantize_func( - self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + self, + x: Union[jnp.ndarray, NoScaleTensor], + is_colwise=False, + dq_dtype=None, + flatten_axis=-1, ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. 
@@ -229,14 +239,17 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - amax = jnp.max(jnp.abs(x)).reshape((1,)) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) - scaled_x = x.astype(compute_dtype) * scale + scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / scale @@ -263,7 +276,10 @@ def quantize( Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype if flatten_axis < 0: flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" 
@@ -347,11 +363,14 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - scaled_x = x.astype(compute_dtype) * self.scale + scaled_x = x.data.astype(compute_dtype) * self.scale # quantize() in the old dot.py do this way, leave this code block here for future debugging # compute_dtype = x.dtype @@ -360,7 +379,8 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale - self.update(jnp.max(jnp.abs(x)).reshape((1,))) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, scale_inv=scale_inv, @@ -460,6 +480,10 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> Returns: A ScaledTensor1x containing the quantized data """ + if isinstance(x, NoScaleTensor): + # No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x. + x = x.data + # TODO(Phuong): use quantize_func from JAX if flatten_axis < 0: flatten_axis = x.ndim + flatten_axis diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 868570f73c..e81a614f0e 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -166,6 +166,90 @@ def get_shardy_sharding_rules( """ +class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): + """Implementation for no scaling mode. + + This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16. 
+ """ + + def get_scale_dtype(self) -> jnp.dtype: + """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling. + + Returns: + The data type used for scale tensors (float32) + """ + return jnp.float32 + + def get_scale_shape( + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, + ) -> Tuple[int, ...]: + """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. + + Args: + data_shape: The shape of the tensor being scaled + is_colwise: Whether the scaling is column-wise + is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors - (1,) + """ + del data_shape, is_colwise, is_padded, flatten_axis + return (0,) + + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return QuantizeLayout.ROWWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. + + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + del data_shape, group_axis, is_colwise + assert isinstance(n_groups, int) + return (n_groups,) + + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. 
+ + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + + class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for current scaling mode. @@ -740,5 +824,5 @@ def tree_unflatten(cls, aux_data, _children): ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 1459175b79..dbbac4abcc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -25,6 +25,8 @@ __all__ = [ "TensorUsage", + "AbstractBaseTensor", + "NoScaleTensor", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", @@ -34,14 +36,9 @@ ] -@register_pytree_node_class @dataclass -class ScaledTensor(ABC): - """Abstract base class for scaled tensors. - - This class defines the interface for all scaled tensor implementations, - providing methods for dequantization and accessing row/column-wise components. 
- """ +class AbstractBaseTensor(ABC): + """Abstract base class for all tensor types.""" @classmethod def tree_unflatten(cls, aux_data, children): @@ -93,9 +90,76 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st """ +@dataclass +class AbstractBaseTensor1x(AbstractBaseTensor): + """Abstract base class for single layout tensors.""" + + data: jnp.ndarray + amax: jnp.ndarray + + @register_pytree_node_class @dataclass -class ScaledTensor1x(ScaledTensor): +class NoScaleTensor(AbstractBaseTensor1x): + """Higher-precision tensor.""" + + def __post_init__(self): + assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray." + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations. + + Returns: + A tuple containing (children, aux_data) for tree operations + """ + children = (self.data, self.amax) + aux_data = () + return (children, aux_data) + + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim + + def dequantize(self): + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" + return self.data + + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert ( + q_layout == QuantizeLayout.ROWWISE + ), "Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. 
+ + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return NoScaleTensor( + data=data, + amax=self.amax, + ) + + +class ScaledTensor(ABC): + """Abstract base class for scaled tensors.""" + + +@register_pytree_node_class +@dataclass +class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): """Single-scale quantized tensor implementation. This class represents a tensor quantized with a single scaling factor, @@ -113,9 +177,7 @@ class ScaledTensor1x(ScaledTensor): flatten_axis: The quantization axis for the tensor """ - data: jnp.ndarray scale_inv: jnp.ndarray - amax: jnp.ndarray scaling_mode: ScalingMode dq_dtype: jnp.dtype _dq_func: Callable @@ -154,7 +216,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax) + children = (self.data, self.amax, self.scale_inv) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -274,15 +336,15 @@ def __init__( self.original_shape = original_shape self.group_axis = group_axis super().__init__( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - _dq_func, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) def __post_init__(self): @@ -339,7 +401,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st @register_pytree_node_class @dataclass -class ScaledTensor2x(ScaledTensor): +class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): """Double-scale quantized tensor implementation. This class represents a tensor quantized with both row-wise and column-wise scaling factors. 
@@ -503,15 +565,15 @@ def create_1x( flatten_axis = data.ndim - flatten_axis return ScaledTensor1x( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - dequantizer.dequantize, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=dequantizer.dequantize, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) @staticmethod @@ -675,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . if isinstance(x, GroupedScaledTensor1x): raise NotImplementedError - if isinstance(x, ScaledTensor): + if isinstance(x, AbstractBaseTensor): return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) From 5b3d65cc1a157ac76d0e4c6342db0c5d80f69984 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Sep 2025 11:26:52 -0400 Subject: [PATCH 08/78] [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen --- transformer_engine/jax/dense.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index b0ba734e5f..8087159a3a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -442,7 +442,7 @@ def _grouped_dense_fwd_rule( ctx_kernel = ScaledTensorFactory.create_1x( global_ctx_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=False, data_layout="N", @@ -459,7 +459,7 @@ def _grouped_dense_fwd_rule( grouped_gemm_kernel = ScaledTensorFactory.create_1x( grouped_gemm_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=True, data_layout="T", From 
aa06107cbc1cc7378c665809c1608c53070447ea Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Mon, 8 Sep 2025 11:28:13 -0400 Subject: [PATCH 09/78] Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen --- tests/jax/multi_process_launch.sh | 6 +++--- ...est_multi_process_distributed_grouped_gemm.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh index 3e0852f393..fcb066de75 100644 --- a/tests/jax/multi_process_launch.sh +++ b/tests/jax/multi_process_launch.sh @@ -12,12 +12,12 @@ XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true export XLA_FLAGS="${XLA_BASE_FLAGS}" -NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader) +NUM_RUNS=$(nvidia-smi -L | wc -l) for ((i=1; i /dev/null 2>&1 & + CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 & done -CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS wait diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 6fce62d8cc..31209d1bc9 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import jax.experimental.multihost_utils as jem from transformer_engine.jax.dense import grouped_dense as te_grouped_dense from transformer_engine.jax.quantize import ( @@ -13,7 +14,7 @@ ScalingMode, ) -from utils import assert_allclose +from utils 
import assert_allclose, dtype_tols N_GROUP = 8 @@ -137,9 +138,16 @@ def run(x, w): out, dx, dw = test_func_jitted(x, w, w_amax) ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) - assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2) - assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2) + e4m3_tols = dtype_tols(jnp.float8_e4m3fn) + e5m2_tols = dtype_tols(jnp.float8_e5m2) + + out, ref_out = jem.process_allgather((out, ref_out)) + dx, ref_dx = jem.process_allgather((dx, ref_dx)) + dw, ref_dw = jem.process_allgather((dw, ref_dw)) + + jnp.allclose(out, ref_out, **e4m3_tols) + jnp.allclose(dx, ref_dx, **e5m2_tols) + jnp.allclose(dw, ref_dw, **e5m2_tols) if __name__ == "__main__": From 603dbf72e4529868bcefd68bd5f901b84093626e Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 8 Sep 2025 10:55:57 -0700 Subject: [PATCH 10/78] Update list of authorized CI users (#2152) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 85a81a6d48..f12a95d79a 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -57,6 +57,7 @@ jobs: || github.actor == 'tdophung' || github.actor == 'vthumbe1503' || github.actor == 'janekb04' + || github.actor == 'shengfangd' ) steps: - name: Check if comment is issued by authorized person From 84fa28d2477c2243ab32cc02ba83faebc59a9e6b Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:54:26 -0700 Subject: [PATCH 11/78] Fused RoPE with combined QKV input. (#2122) * Fused RoPE with combined QKV input. 
Initial commit for Dropout with 8-bit RNG Fix documentation Initial commit for Fused QKV RoPE WIP Initial tests passing Enable rotary percent and margin Enable CP2, start_positions, interleaved Cleanup test Revert "Fix documentation" This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715. Revert "Initial commit for Dropout with 8-bit RNG" This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca. Cleanup. Minor cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernels Signed-off-by: Vasudevan Rengasamy * Misc. Cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernel performance Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Move fused_qkv_rope test to test_fused_rope.py Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply shared memory optimization to separate fused rope kernels Signed-off-by: Xin Yao * fix lint Signed-off-by: Xin Yao --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fused_rope.py | 147 +++++- .../common/fused_rope/fused_rope.cu | 457 ++++++++++++++++-- .../include/transformer_engine/fused_rope.h | 63 +++ transformer_engine/pytorch/attention/rope.py | 163 ++++++- transformer_engine/pytorch/csrc/extensions.h | 12 + .../pytorch/csrc/extensions/apply_rope.cpp | 97 ++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + 7 files 
changed, 898 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index ae25af9499..62d80b5529 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,25 +1,32 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Union, List import math import torch import pytest from transformer_engine.pytorch.attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, + apply_fused_qkv_rotary_pos_emb, ) # Gradient is a broadcasted scalar -def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: - return output.sum() * 2 +def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(t.sum() * 2 for t in output) + else: + return output.sum() * 2 # Gradient is a full tensor -def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: - t = torch.ones_like(output) - return torch.sum(output * t) +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(torch.sum(t * torch.ones_like(t)) for t in output) + else: + t = torch.ones_like(output) + return torch.sum(output * t) @pytest.mark.parametrize("start_positions", [True, False]) @@ -238,3 +245,131 @@ def test_fused_rope_thd( torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() + + +@pytest.mark.parametrize("start_positions", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("tensor_format", ["sbhd", 
"bshd"]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) +def test_fused_qkv_rope( + dtype: torch.dtype, + seq_length: int, + hidden_size: int, + rotary_percent: float, + margin: int, + tensor_format: str, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + if margin == 0 and start_positions == True: + # This makes sure that the `start_positions` offsets being applied + # are with the maximum length of the rope embeddings. + pytest.skip("Skipping test with margin=0 and start_positions=True") + + if start_positions == True and cp_size > 1: + # `start_positions` is only supported for `cp_size=1` and inference. + pytest.skip("Skipping test with cp_size>1 and start_positions=True") + + if seq_length - margin < 0: + pytest.skip("Skipping test with seq_length - margin < 0") + + device = torch.device("cuda:0") + batch_size, head_num = 2, 64 + + t = torch.rand( + (seq_length - margin, batch_size, head_num, hidden_size * 6), + dtype=dtype, + device=device, + ) + + # Get arbitrary offsets to be used with RoPE for all the sequences + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + if tensor_format == "bshd": + t = t.transpose(0, 1).contiguous() + t.requires_grad = True + + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_q = rotary_pos_emb_q(seq_length * cp_size) + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_k = rotary_pos_emb_k(seq_length * cp_size) + + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + + t_clone = t.clone() + (query, key, value) = torch.split( + t_clone, 
[hidden_size * 4, hidden_size, hidden_size], dim=3 + ) + query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) + + query_unfused = apply_rotary_pos_emb( + query, + emb_q, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + key_unfused = apply_rotary_pos_emb( + key, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + value_unfused = value + loss_unfused = loss_func([query_unfused, key_unfused, value_unfused]) + + if not isinstance(start_positions, torch.Tensor): + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + + t.grad = None + + # fused + query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb( + t, + emb_q, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + cp_size=cp_size, + cp_rank=cp_rank, + qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size], + ) + loss_fused = loss_func([query_fused, key_fused, value_fused]) + + if not isinstance(start_positions, torch.Tensor): + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(query_fused, query_unfused) + torch.testing.assert_close(key_fused, key_unfused) + torch.testing.assert_close(value_fused, value_unfused) + + if not isinstance(start_positions, torch.Tensor): + torch.testing.assert_close(grad_fused, grad_unfused) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index df9ea6ee5f..ccd0bc44c5 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs const int h, const int d, const int d2, 
const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos, v_sin; - sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos = shared_mem_cos[d_id]; + float v_sin = shared_mem_sin[d_id]; int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; @@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq const int h, const int d, const int d2, 
const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin; - if (!interleaved) { - v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); - } else { - v_sin = - (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); - } + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate; + float v_cos = shared_mem_cos[d_id]; + float v_src_rotate, v_sin; if (!interleaved) { - v_src_rotate = (d_id + d2 / 2 < d2) - ? static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(src[offset_src + (d2 / 2) * stride_d]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } } else { - v_src_rotate = (d_id % 2 == 0) - // d_id + 1 - ? 
static_cast(src[offset_src + stride_d]) - // d_id - 1 - : static_cast(src[offset_src - stride_d]); + if (d_id % 2 == 0) { + v_src_rotate = static_cast(src[offset_src + stride_d]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(src[offset_src - stride_d]); + v_sin = -shared_mem_sin[d_id - 1]; + } } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } - // handle the tail + // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +template +__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + // Split the shared memory into cos and sin parts for q or k + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk 
+ 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); + +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + v_sin = shared_mem_sin[d_id]; + float v_src = src[offset_src]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2)]) + : static_cast(src[offset_src + (d2 / 2 - d2)]); + } else { + v_src_rotate = (d_id % 2 == 0) ? -static_cast(src[offset_src + 1]) + : static_cast(src[offset_src - 1]); + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = src[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = src[offset_src]; + } + } + } + } +} + +template +__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs, + scalar_t *out, const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int 
out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + // Split the shared memory into cos and sin parts for q or k + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + + float v_src = grad_out[offset_src]; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + float v_src_rotate; + if (!interleaved) { + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2)]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2 - d2)]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } + } else { + if (d_id % 2 == 0) { + v_src_rotate = static_cast(grad_out[offset_src + 1]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(grad_out[offset_src - 1]); + v_sin = -shared_mem_sin[d_id - 1]; + } + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = grad_out[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i 
< out_row_length; i += d) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = grad_out[offset_src]; + } + } + } + } +} + +template +__global__ void fused_qkv_rope_forward_kernel( + const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs, + const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int total_d = q_split_arg + k_split_arg + v_split_arg; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + int 
begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; + } + fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg); + fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg); + fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg); +} + +template +__global__ void fused_qkv_rope_backward_kernel( + const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v, + const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + int total_d = q_split_arg + k_split_arg + v_split_arg; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; 
+ if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_q, h, d, d2, 0, total_d, + q_split_arg); + fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d, + k_split_arg); + fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d, + v_split_arg); +} + template void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *output, @@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_forward_kernel<<>>( + fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); @@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_backward_kernel<<>>( + fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs, + const float *k_freqs, const int *start_positions, + scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + int warps_per_block = (h <= 8) ? 
h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_forward_kernel<<>>( + qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out, + const scalar_t *v_grad_out, const float *q_freqs, + const float *k_freqs, scalar_t *qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + const int warps_per_block = (h <= 8) ? 
h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_backward_kernel<<>>( + q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, const Tensor &start_positions, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, @@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c stride_b, stride_h, stride_d, stream);); } +void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs, + const Tensor &start_positions, Tensor *q_out, Tensor *k_out, + Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + qkv_input.data.dtype, scalar_t, + fused_qkv_rope_forward_launcher(reinterpret_cast(qkv_input.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), + reinterpret_cast(q_out->data.dptr), + reinterpret_cast(k_out->data.dptr), + reinterpret_cast(v_out->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} + +void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, + const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs, + Tensor *qkv_grad_input, const 
NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + q_grad_out.data.dtype, scalar_t, + fused_qkv_rope_backward_launcher(reinterpret_cast(q_grad_out.data.dptr), + reinterpret_cast(k_grad_out.data.dptr), + reinterpret_cast(v_grad_out.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(qkv_grad_input->data.dptr), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} } // end namespace transformer_engine void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, @@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } + +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_forward); + using namespace transformer_engine; + fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out), + convertNVTETensorCheck(v_out), qkv_format, 
interleaved, cp_size, cp_rank, + s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream); +} + +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_backward); + using namespace transformer_engine; + fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out), + *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index f0817a97fe..610868f932 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); +/*! \brief Apply rotary positional embedding to the combined QKV input tensor. + * + * \param[in] qkv_input Combined QKV input tensor for fused rope. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. + * \param[out] q_out Output tensor for Q. 
+ * \param[out] k_out Output tensor for K. + * \param[out] v_out Output tensor for V. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + +/*! \brief Compute the backward of the fused qkv rope. + * + * \param[in] q_grad_out Incoming gradient tensor for Q. + * \param[in] k_grad_out Incoming gradient tensor for K. + * \param[in] v_grad_out Incoming gradient tensor for V. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[out] qkv_grad_input Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. 
+ * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 60685a31d9..139381f2dd 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,14 +5,14 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -170,6 +170,86 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], return grad_input, None, None, None, None, None, None, None +class FusedQKVRoPEFunc(torch.autograd.Function): + """ + 
Function for FusedQKVRoPE + + This implementation accepts combined QKV tensor in `bshd` or `sbhd` format. Q and K RoPE tensors are the additional required inputs. + The RoPE tensors should be of shape (s, 1, 1, d). It produces 3 outputs: Q, K after RoPE, V is the same as input. + """ + + @staticmethod + def forward( + ctx, + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + start_positions: Union[torch.Tensor, None] = None, + tensor_format: str = "sbhd", + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """Fused RoPE forward.""" + + if q_freqs.dtype != torch.float32: + q_freqs = q_freqs.float() + if k_freqs.dtype != torch.float32: + k_freqs = k_freqs.float() + assert tensor_format in ( + "sbhd", + "bshd", + ), f"Unsupported tensor_format: {tensor_format}." + assert qkv.is_contiguous(), "QKV Tensor should be contiguous." + assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous." + assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous." 
+ output = tex.fused_qkv_rope_forward( + qkv, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + QKVFormat[tensor_format], + interleaved, + cp_size, + cp_rank, + ) + ctx.save_for_backward(q_freqs, k_freqs) + ctx.tensor_format = tensor_format + ctx.qkv_split_arg_list = qkv_split_arg_list + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + ctx.interleaved = interleaved + return output + + @staticmethod + def backward( + ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused RoPE backward.""" + q_freqs, k_freqs = ctx.saved_tensors + + grad_output_q = grad_output_q.contiguous() + grad_output_k = grad_output_k.contiguous() + grad_output_v = grad_output_v.contiguous() + + grad_input = tex.fused_qkv_rope_backward( + grad_output_q, + grad_output_k, + grad_output_v, + q_freqs, + k_freqs, + ctx.qkv_split_arg_list, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None, None, None + + def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] @@ -393,3 +473,82 @@ def apply_rotary_pos_emb( tensor_format, interleaved=interleaved, ) + + +def apply_fused_qkv_rotary_pos_emb( + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, + interleaved: bool = False, + cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input qkv tensor. 
+ + Support matrix: + Fused: + Training: + qkv_formats: "bshd", "sbhd" + context parallel: yes + start_positions: no + interleaving: yes + Inference: + qkv_formats: "bshd", "sbhd" + context parallelism: no + start_positions: yes + interleaving: yes + + Parameters + ---------- + qkv: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which + rotary positional embedding will be applied. This tensor has q, k, v concatenated + along the last dimension. + q_freqs: torch.Tensor + Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + k_freqs: torch.Tensor + Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + qkv_split_arg_list: List[int] + List of integers that specify the split of the qkv tensor. The list should have 3 elements, + the first element is the number of elements in the q tensor, the second element is the number + of elements in the k tensor, and the third element is the number of elements in the v tensor. + The sum of the elements in the list should be equal to the last dimension of the qkv tensor. + start_positions: torch.Tensor, default = None. + Tokens in a sequence `i` should be applied with position encoding offset by + `start_positions[i]`. If `start_positions=None`, there's no offset. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is + of shape `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + cp_size: int, default = 1. + Context parallel world size. + cp_rank: int, default = 0. + Context parallel rank. + """ + + # `start_positions` is only supported for `cp_size=1` and inference. 
+ assert not ( + cp_size > 1 and start_positions is not None + ), """start_positions != None with CP SIZE > 1 is not supported!""" + + assert tensor_format != "thd", "'thd' tensor_format not supported currently." + + return FusedQKVRoPEFunc.apply( + qkv, + q_freqs, + k_freqs, + qkv_split_arg_list, + start_positions, + tensor_format, + interleaved, + cp_size, + cp_rank, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a6b65562eb..4cb05725bc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor const std::optional cu_seqlens, const int cp_size, const int cp_rank); +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); + +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank); + /*************************************************************************************************** * Miscellaneous **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 6f6f827252..d1ba1a351c 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor 
&input, const at::Tensor &freqs, return output; } +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, + const int cp_rank) { + TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + // output + auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device()); + auto q_out_size = qkv_input.sizes().vec(); + q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1]; + q_out_size[3] = qkv_split_arg_list[1]; + auto q_out = at::empty(q_out_size, act_options); + auto k_out_size = qkv_input.sizes().vec(); + k_out_size[3] = qkv_split_arg_list[1]; + auto k_out = at::empty(k_out_size, act_options); + auto v_out_size = qkv_input.sizes().vec(); + v_out_size[3] = qkv_split_arg_list[2]; + auto v_out = at::empty(v_out_size, act_options); + + auto qkv_cu = makeTransformerEngineTensor(qkv_input); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto k_out_cu = makeTransformerEngineTensor(k_out); + auto v_out_cu = makeTransformerEngineTensor(v_out); + + auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + if (start_positions) { + start_positions_cu 
= makeTransformerEngineTensor(start_positions.value()); + } + + TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); + TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); + + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1); + const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0); + const int h = qkv_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), + start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(), + v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, + d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1], + qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, @@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor return input_grads; } +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank) { + auto act_options = + at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device()); + auto qkv_grad_size = q_grad_out.sizes().vec(); + auto total_hd = + (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3); + auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2]; + qkv_grad_size[2] = total_hd / total_d; + qkv_grad_size[3] = total_d; + auto qkv_grad_input = 
at::empty(qkv_grad_size, act_options); + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1); + const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0); + const int h = qkv_grad_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out); + auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out); + auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input); + + nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), + q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0], + qkv_split_arg_list[1], qkv_split_arg_list[2], + at::cuda::getCurrentCUDAStream()); + + return qkv_grad_input; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 541b16848e..7649ccb6d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Apply RoPE FWD", py::call_guard()); m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); + m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward, + "Fused Apply QKV RoPE FWD", py::call_guard()); + m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, + "Fused Apply QKV RoPE BWD", py::call_guard()); // fused router m.def("fused_topk_with_score_function_fwd", From 
a26a7f1f660a416ad790123b577ba665191222db Mon Sep 17 00:00:00 2001 From: Autumn1998 <1515848689@qq.com> Date: Tue, 9 Sep 2025 10:22:24 +0800 Subject: [PATCH 12/78] Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162) * add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/fused_router/fused_moe_aux_loss.cu | 2 +- transformer_engine/common/fused_router/utils.h | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index a738be8736..94082594f6 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -229,7 +229,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - IndexType tokens_per_expert_i = tokens_per_expert[i]; + double tokens_per_expert_i = static_cast(tokens_per_expert[i]); double grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 46e0ba632c..b6f9d87bdc 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -246,6 +246,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using type = int64_t; \ { __VA_ARGS__ } \ } break; \ + case DType::kBFloat16: { \ + using type = 
bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } From 5f2b83100c75fb633ef416d3685efb5e23062f5c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Sep 2025 07:43:28 -0400 Subject: [PATCH 13/78] [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen * added outer_impl Signed-off-by: Phuong Nguyen * clean up FFI Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 9 +- transformer_engine/jax/cpp_extensions/gemm.py | 97 +++++++++++++------ .../jax/csrc/extensions/gemm.cpp | 52 ++-------- 3 files changed, 81 insertions(+), 77 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index a27cec001a..c055705665 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -134,6 +134,13 @@ def impl(): """ return NotImplemented + @classmethod + def outer_impl(cls, *args, **kwargs): + """ + to describe implementation for outer primitive + """ + return cls.impl(*args, **kwargs) + @staticmethod @abstractmethod def batcher(): @@ -196,7 +203,7 @@ def name_of_wrapper_p(): outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results - outer_p.def_impl(cls.impl) + outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index acc8d67274..2acc3fb68c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -152,6 +152,21 @@ def 
_quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +@partial(jax.jit, static_argnums=(1, 2)) +def swizzled_scale(scale_inv, flatten_axis, is_colwise): + "Swizzle scale_inv via JAX transpose ops" + original_shape = scale_inv.shape + shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:])) + if is_colwise: + scale_inv = jnp.transpose(scale_inv.reshape(shape_2d)) + cols, rows = shape_2d + else: + rows, cols = shape_2d + reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4) + swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4)) + return swizzled.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -286,28 +301,18 @@ def _dims_are_consecutive(dims): ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - # Need extra workspace for swizzled scale factors - lhs_swizzle_size = 0 - rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_swizzle_size = lhs_scale_inv.size - rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) - rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) - # Declare cuBLAS workspace # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. 
workspace_size = get_cublas_workspace_size_bytes() + 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + return output, bias_grad, pre_gelu_out, workspace @staticmethod def outer_abstract(*args, **kwargs): outputs = GemmPrimitive.abstract(*args, **kwargs) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array @staticmethod def lowering( @@ -374,24 +379,22 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout( - (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) - ) - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, - scaling_mode, - lhs.shape, - is_colwise=lhs_transposed, - flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, - scaling_mode, - rhs.shape, - is_colwise=not rhs_transposed, - flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - ) + if scaling_mode.is_1d_block_scaling(): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) + rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) + lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) + rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) outputs = GemmPrimitive.inner_primitive.bind( lhs, 
@@ -408,7 +411,39 @@ def impl( grad=grad, use_split_accumulator=use_split_accumulator, ) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array + + @staticmethod + def outer_impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + return GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ) @staticmethod def batcher( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 032ac9eb70..113072131d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, - JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, + size_t axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -61,40 +61,6 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } - - // Swizzle scaling factors for MXFP8 - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Get the swizzle buffer - NVTE_CHECK(swizzled_scale_inv->element_count() > 0, - "Missing swizzled inverse scale buffer in the JAX primitive."); - auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto 
swizzled_scale_inv_dtype = - convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); - NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, - "Inverse scale factors need to have an 8-bit data type."); - - // Create tensor to hold swizzled scale factor - TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); - if (rowwise) { - output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - - // Launch swizzle kernel - nvte_swizzle_scaling_factors(input.data(), output.data(), stream); - - // Set swizzled scales into the input tensor - if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - } } return std::make_tuple(std::move(input), input_shape); @@ -103,21 +69,19 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { - // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // 
device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, + lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, + rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? 
lhs_shape[1] : lhs_shape[0], @@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") From 4903f947d6de871cd92c3478c8b5b78f835d5b7f Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 9 Sep 2025 23:52:01 -0700 Subject: [PATCH 14/78] Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov --- qa/L0_cppunittest/test.sh | 2 +- qa/L1_cpp_distributed/test.sh | 10 ++-- tests/cpp/CMakeLists.txt | 1 - tests/cpp/comm_gemm/CMakeLists.txt | 19 ------- tests/cpp_distributed/CMakeLists.txt | 57 +++++++++++++++++++ .../test_comm_gemm.cu | 2 +- 6 files changed, 65 insertions(+), 26 deletions(-) delete mode 100644 tests/cpp/comm_gemm/CMakeLists.txt create mode 100644 tests/cpp_distributed/CMakeLists.txt rename tests/{cpp/comm_gemm => cpp_distributed}/test_comm_gemm.cu (99%) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index aa56d69ed6..cd46b0b63c 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . 
cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' +ctest --test-dir build -j$NUM_PARALLEL_JOBS diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index f4f914b3e9..e074b46ae6 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -9,7 +9,9 @@ set -e TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH -cd $TE_PATH/tests/cpp -cmake -GNinja -S. -Bbuild -cmake --build build -mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm +if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then + cd $TE_PATH/tests/cpp_distributed + cmake -GNinja -S. -Bbuild + cmake --build build + mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm +fi diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 412c5d34d9..c2c9d0d915 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,6 +43,5 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) -add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt deleted file mode 100644 index 55f5207acf..0000000000 --- a/tests/cpp/comm_gemm/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -add_executable(test_comm_gemm - test_comm_gemm.cu - ../test_common.cu) - -find_package(OpenMP REQUIRED) -find_package(MPI REQUIRED) -find_library(NCCL_LIB - NAMES nccl libnccl - PATH_SUFFIXES lib - REQUIRED) -target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) -target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) - -include(GoogleTest) -gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt new file mode 100644 index 0000000000..ed3ddeb885 --- /dev/null +++ b/tests/cpp_distributed/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() +endif() + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +project(transformer_engine_distributed_tests LANGUAGES CUDA CXX) + +add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) + +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) + +if(NOT DEFINED TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) +endif() + +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." 
${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) + +message(STATUS "Found transformer_engine library: ${TE_LIB}") +include_directories(../../transformer_engine/common/include) +include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) +include_directories(${CMAKE_SOURCE_DIR}) + +find_package(CUDAToolkit REQUIRED) + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../cpp/test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu similarity index 99% rename from tests/cpp/comm_gemm/test_comm_gemm.cu rename to tests/cpp_distributed/test_comm_gemm.cu index b34d4db4b8..8355d5f96f 100644 --- a/tests/cpp/comm_gemm/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_common.h" +#include "../cpp/test_common.h" #include "common.h" using transformer_engine::DType; From 483d9594fb070f62966f6a12ed6c90942310b48e Mon Sep 17 00:00:00 2001 From: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:54:43 -0700 Subject: [PATCH 15/78] Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell * assert line change Signed-off-by: Jonathan Mitchell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan 
Mitchell Co-authored-by: Jonathan Mitchell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/attention/test_cp_utils.py | 715 ++++++++++++++++++ .../dot_product_attention/context_parallel.py | 211 +++++- 3 files changed, 926 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/attention/test_cp_utils.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index e5b4b58617..7f061d222a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py new file mode 100644 index 0000000000..00200c62d2 --- /dev/null +++ b/tests/pytorch/attention/test_cp_utils.py @@ -0,0 +1,715 @@ +# Copyright (c) 2022-2025, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for context parallel utils.""" +import torch +import unittest +from typing import Tuple +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + get_batch_on_this_cp_rank, + pad_thd_sequences_for_cp, + generate_positional_ids_for_cp, +) + + +class TestSequencePadding(unittest.TestCase): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + """Test with custom padding values for all tensors.""" + # Setup + + input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3]) + cu_seqlens = torch.tensor([0, 3, 5, 9]) + labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100]) + positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + divisibility_factor = 8 + + pid = 777 + label_pad = -200 + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] + print("input_ids_padded: ", input_ids_padded) + print("labels_padded: ", labels_padded) + print("positional_ids_padded: ", positional_ids_padded) + print("cu_seqlens_padded: ", cu_seqlens_padded) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + pid, + pid, + pid, + pid, + pid, + 2, + 2, + pid, + pid, + pid, + pid, + pid, + pid, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + pid, + ] + ) + expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) + expected_labels_padded = torch.tensor( + [ + -100, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + 13, + -100, + 
label_pad, + label_pad, + label_pad, + label_pad, + ] + ) + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + assert torch.equal(input_ids_padded, expected_input_ids) + assert torch.equal(labels_padded, expected_labels_padded) + assert torch.equal(positional_ids_padded, expected_positional_ids) + assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded) + + def test_mixed_sequence_lengths_with_divisibility_factor(self): + """Test with sequences both shorter and longer than divisibility factor.""" + # Setup - divisibility factor 6 + # Seq 1: length 2 (shorter than 6, needs 4 padding) + # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12) + # Seq 3: length 4 (shorter than 6, needs 2 padding) + # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) + + input_ids = torch.tensor( + [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + ) + labels = torch.tensor( + [ + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 30, + 31, + 32, + 33, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + ] + ) + positional_ids = torch.tensor( + [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) + divisibility_factor = 6 + + pid = 999 + label_pad = -300 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: [1,1] + 4 pads = 6 total + # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total + # Seq 3: [3,3,3,3] + 2 pads = 6 total + # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total + + expected_input_ids = torch.tensor( + [ + 1, + 1, + pid, + pid, + pid, + pid, # Seq 1: 2 + 4 padding + 2, + 2, + 2, 
+ 2, + 2, + 2, + 2, + pid, + pid, + pid, + pid, + pid, # Seq 2: 7 + 5 padding + 3, + 3, + 3, + 3, + pid, + pid, # Seq 3: 4 + 2 padding + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + pid, + pid, # Seq 4: 10 + 2 padding + ] + ) + + expected_labels = torch.tensor( + [ + 10, + 11, + label_pad, + label_pad, + label_pad, + label_pad, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + 30, + 31, + 32, + 33, + label_pad, + label_pad, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [ + 0, + 1, + 2, + 3, + 4, + 5, # Seq 1 positions continue through padding + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 2 positions continue + 0, + 1, + 2, + 3, + 4, + 5, # Seq 3 positions continue + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 4 positions continue + ] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + def test_sequences_longer_than_divisibility_factor(self): + """Test with all sequences longer than the divisibility factor.""" + # Setup - divisibility factor 4, all sequences longer than 4 + # Seq 1: length 7 (needs 1 padding to reach 8) + # Seq 2: length 11 (needs 1 padding to reach 12) + # Seq 3: length 5 (needs 3 padding to reach 8) + + input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, # 7 tokens + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, # 11 tokens + 3, + 3, + 3, + 3, + 3, # 5 tokens + ] + ) + labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 
300, + 301, + 302, + 303, + 304, + ] + ) + positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4] + ) + cu_seqlens = torch.tensor([0, 7, 18, 23]) + divisibility_factor = 4 + + pid = 888 + label_pad = -400 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: 7 + 1 pad = 8 (divisible by 4) + # Seq 2: 11 + 1 pad = 12 (divisible by 4) + # Seq 3: 5 + 3 pads = 8 (divisible by 4) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + pid, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + 3, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + ] + ) + + expected_labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + label_pad, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + label_pad, + 300, + 301, + 302, + 303, + 304, + label_pad, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + +class TestContextParallelUtils(unittest.TestCase): + """Test utilities for context parallel functionality.""" + + def setUp(self): + """Set up mock distributed environment.""" + # Mock torch.distributed functions + self.original_get_world_size = torch.distributed.get_world_size + 
self.original_get_rank = torch.distributed.get_rank + + def tearDown(self): + """Restore original torch.distributed functions.""" + torch.distributed.get_world_size = self.original_get_world_size + torch.distributed.get_rank = self.original_get_rank + + def _mock_distributed_env(self, cp_size, cp_rank): + """Mock the distributed environment for testing.""" + + def mock_get_world_size(group=None): + return cp_size + + def mock_get_rank(group=None): + return cp_rank + + torch.distributed.get_world_size = mock_get_world_size + torch.distributed.get_rank = mock_get_rank + + def test_cp_rank_slicing_simple_case(self): + """Test CP rank slicing with a simple 2-rank, single sequence case.""" + # Setup: Single sequence of length 8, CP size = 2 + # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each + # Rank 0 gets slices [0,1] and [6,7] (first and last) + # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 
1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]]) + expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) + expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_cp_rank_slicing_multiple_sequences(self): + """Test CP rank slicing with multiple sequences.""" + # Setup: Two sequences of length 8 each, CP size = 2 + # Total sequence length = 16, cu_seqlens = [0, 8, 16] + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) + labels = torch.tensor( + [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]] + ) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8, 16]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # For each sequence, rank 0 gets first and last slices + # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] + # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_with_cp_size_1(self): + """Test that CP size = 1 returns original tensors unchanged.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + 
cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=1, cp_rank=0) + input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # With CP size = 1, should return original tensors + self.assertTrue(torch.equal(input_ids_result, input_ids)) + self.assertTrue(torch.equal(labels_result, labels)) + self.assertTrue(torch.equal(pos_ids_result, position_ids)) + + def test_cp_rank_slicing_sequence_dim_detection(self): + """Test that the function correctly detects sequence dimension.""" + # Test with sequence dimension = 0 (sequence_length, batch_size) + input_ids = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) # (8, 2) + labels = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) + position_ids = torch.tensor( + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]] + ) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Should get indices [0,1] and [6,7] along dimension 0 + expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_mixed_dimensions(self): + """Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension.""" + # Setup: Single sequence of length 8, CP size = 2 + # This tests the opposite case from the simple test: + # - input_ids and labels: 1D (no batch dimension) + # - position_ids: 2D (has batch dimension) 
+ + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D + labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D + position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result + expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result + expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result + expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result + expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_integration_with_padding_and_cp_slicing(self): + """Integration test: pad sequences then slice for CP ranks.""" + # Start with unpadded sequences + input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2] + labels = torch.tensor([10, 11, 20, 21, 22]) + positional_ids = torch.tensor([0, 1, 0, 1, 2]) + cu_seqlens = torch.tensor([0, 2, 5]) + divisibility_factor = 4 # Will pad to lengths 4 and 
4 + + # First, pad sequences + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=0, + padding_label_id=-100, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] + expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) + self.assertTrue(torch.equal(input_ids_padded, expected_padded)) + + # Now test CP slicing with cp_size=2 + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), + labels_padded.unsqueeze(0), + positional_ids_padded, + ) + + # Each sequence of length 4 gets divided into 4 slices of size 1 + # Rank 0 gets slices [0] and [3] from each sequence + # Seq 1: indices [0] and [3] -> values [1] and [0] + # Seq 2: indices [4] and [7] -> values [2] and [0] + expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c6f4647c04..f00bd573f1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,7 +4,7 @@ """Context Parallelism.""" import os -from typing import List, Union +from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -3927,3 +3927,212 @@ def attn_forward_func_with_cp( raise ValueError(f"Unsupported communication type: {cp_comm_type}!") return out + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + 
labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Extract sequences and labels for each batch item + batch_sequences = [ + input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + batch_labels = [ + labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + ( + torch.cat([seq, 
torch.full((pad,), padding_label_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum( + torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0 + ) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. + + Args: + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + dtype: Data type for the generated positional IDs (default: torch.long) + + Returns: + Generated positional_ids tensor where each sequence starts from 0 and continues through padding + """ + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Generate positional IDs for each padded sequence (each starts from 0) + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + positional_ids = torch.cat( + [torch.arange(0, int(length), dtype=dtype) for length in padded_lengths] + ) + + return positional_ids + + +def get_batch_on_this_cp_rank( + cu_seqlens_padded: torch.Tensor, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + position_ids_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup = None, + qvk_format: str = "thd", +): + """Slice batch input along sequence dimension into multiple chunks for THD format. 
+ + This function is intended for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + + The resulting chunks are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + if qvk_format == "thd": + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + if cp_size > 1: + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. 
+ cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + position_ids_padded = process_tensor(position_ids_padded) + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded, position_ids_padded From 405d474b39d0975a5c2a732d68e3ba9cfe28313b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:29:20 +0200 Subject: [PATCH 16/78] [PyTorch Debug] Fix issue with negative underflow% stat. 
(#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/debug/test_api_features.py | 12 +++++---- tests/pytorch/debug/test_log.py | 10 +++---- .../debug/features/utils/stats_computation.py | 26 +++++++++++++------ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index 974772599a..d28db16477 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -268,7 +268,7 @@ def assert_empty(): )[0] expected_underflows = ( - ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) + ((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) ) assert debug_api.transformer_engine.inspect_tensor_enabled( @@ -302,7 +302,7 @@ def assert_empty(): )[0] # Second config in same yaml - tensor = torch.rand((100, 100, 5)) + tensor = torch.rand((100, 100, 5)).cuda() debug_api.transformer_engine.inspect_tensor( "decoder.6.mlp.fc1", tensor_name="activation", @@ -316,7 +316,9 @@ def assert_empty(): stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in 
["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) - assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + torch.testing.assert_close( + stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean() + ) debug_api.transformer_engine.inspect_tensor( "decoder.7.mlp.fc1", @@ -331,7 +333,7 @@ def assert_empty(): stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"]) - assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max()) assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 @@ -377,7 +379,7 @@ def fp8_tensor(t): return quantizer(t.cuda()) shape = [1024, 1024] - tensors = [torch.randn(shape) for _ in range(2)] + tensors = [torch.randn(shape).cuda() for _ in range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] feed(tensors[0], tensors_fp8[0], quantizer) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index ca8e10ad69..dcc9861c84 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -167,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs): num_quantizers=3, ) - tensor = torch.zeros(1024, 1024).cuda() - tensor[0, :] = 1000 + tensor = torch.randn(1024, 1024).cuda() + tensor[0, 100:200] = -0.0 quantizer = recipe_state.make_quantizers()[0] quantized_tensor = quantizer(tensor) @@ -191,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs): if "underflows%" in line: underflows = float(line.split("value=")[1]) expected = ( - ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) - / dequantized_tensor.numel() - * 100 + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 ) assert underflows == pytest.approx(expected.cpu(), abs=1e-4) if "mse" in line: mse = 
float(line.split("value=")[1]) expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") - assert mse == pytest.approx(expected.cpu(), abs=1e-6) + assert mse == pytest.approx(expected.cpu(), abs=1e-4) if "overflows%" in line: overflows = float(line.split("value=")[1]) expected = ( diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 3842ab1c56..2fa6985acf 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -199,6 +199,15 @@ def _get(buffers, stat_name): ), } +FP8_NEGATIVE_ZERO = 128 # represents -0.0 in fp8 + + +def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the fp8 data.""" + fp8_data = fp8_data.view(dtype=torch.uint8) + zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) + return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() + def add_underflows_stats(recipe_name: str, columnwise: bool = False): """Register *both* underflow stats (num and %) for the given recipe.""" @@ -212,22 +221,23 @@ stats_to_num[stat_pct] = len(stats_to_num) STATS[stat_num] = ( - lambda x, aux_dict: ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_fp8( aux_dict[recipe_name].get_data_tensors( rowwise_data=not columnwise, columnwise_data=columnwise ) - == 0 - ).sum() - - (x == 0).sum(), + ), lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), ) STATS[stat_pct] = ( lambda x, aux_dict: ( - aux_dict[recipe_name].get_data_tensors( - rowwise_data=not columnwise, columnwise_data=columnwise + x.count_nonzero() + - count_nonzero_fp8( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) ) - == 0 - ).sum() + ) / aux_dict[recipe_name].numel() * 100, lambda 
buffers, _sn_num=stat_num: 100 From cd2034f3f28ec07ef4feb18d469a393a9cd2596f Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Mon, 15 Sep 2025 17:08:28 -0400 Subject: [PATCH 17/78] Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as Norm outputs' and activation compuations. Signed-off-by: Ming Huang * Adding knob to control norm output precision. Signed-off-by: Ming Huang * Removing the knob and applying lower-precision norm with current-scaling only. Signed-off-by: Ming Huang * Fix the error when quantizer==None Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang --- tests/jax/test_custom_call_compute.py | 13 +++++++++++-- transformer_engine/jax/cpp_extensions/activation.py | 6 +++--- .../jax/cpp_extensions/normalization.py | 4 ++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 11f07d9133..9e39b84c0b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -465,14 +465,23 @@ def _test_norm_forward( x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_mu, ref_rsigma = _jax_layernorm( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) else: output, rsigma = tex.rmsnorm_fwd( x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_rsigma = _jax_rmsnorm( - x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) ref_mu = None diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d3c7d2b086..cdda201668 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1045,7 +1045,7 @@ def act_lu( if 
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, ) @@ -1178,8 +1178,8 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( - dz=dz.astype(jnp.float32), - x=x.astype(jnp.float32), + dz=dz, + x=x, activation_type=activation_type, quantizer=None, ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index de1877de5c..7a978c1b74 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) output = normed_input * gamma + beta if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) @@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): output = normed_input * gamma if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) From 59130cc9d0bd7cc66457556373d731ca0744cf9b Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:12:12 -0700 Subject: [PATCH 18/78] [PyTorch] Support activation CPU offloading in fusible ops (#2158) * Add CPU offloading logic to ops. Fix test to compute dgrad. 
Signed-off-by: Tim Moon * Make sure grads are contiguous in op backwards Signed-off-by: Tim Moon * Add op-based MLP to CPU offloading tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle different weight cache behavior on Hopper/Blackwell Add MXFP8 to CPU offload tests. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove MXFP8 test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_cpu_offloading.py | 196 +++++++++++------- transformer_engine/pytorch/ops/_common.py | 4 +- .../pytorch/ops/basic/activation.py | 3 + .../pytorch/ops/basic/basic_linear.py | 3 + .../pytorch/ops/basic/dropout.py | 3 + .../pytorch/ops/basic/l2normalization.py | 9 +- .../pytorch/ops/basic/layer_norm.py | 7 +- .../pytorch/ops/basic/rmsnorm.py | 7 +- .../fused/forward_linear_bias_activation.py | 13 +- .../ops/fused/forward_linear_bias_add.py | 15 +- .../ops/fused/forward_linear_scale_add.py | 5 +- .../ops/fused/userbuffers_forward_linear.py | 3 + 12 files changed, 174 insertions(+), 94 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 0b0732dfa8..0e01f0b04a 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,8 +2,11 @@ # # See LICENSE for license information. 
+import contextlib +import gc import os -from contextlib import nullcontext +from typing import Iterable, Optional + import pytest import torch @@ -11,15 +14,16 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported from utils import ModelConfig, get_available_attention_backends -# Check if FP8 is supported +# Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() -fp8_recipes = [None] +quantization_recipes: Optional[recipe.Recipe] = [None] if fp8_available: - fp8_recipes.append(recipe.Float8CurrentScaling()) - fp8_recipes.append(recipe.DelayedScaling()) + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -48,85 +52,139 @@ "transformer_layer": lambda: te.TransformerLayer( SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), } -def _get_input(): - return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) -def _get_fp8_weight_cache_size(models, fp8_recipe): - """ - Calculate the total FP8 weight cache size (in MB) for a list of models. 
- """ - if fp8_recipe is None: +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.fp8_autocast( + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: return 0 - params_bytes = 0 - for model in models: - for name, param in model.named_parameters(): - if "weight" in name: - params_bytes += param.numel() + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> 
float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. - # One byte for columnwise and one byte for rowwise, - # hence multiply by 2 and convert to MB - # there is 1 byte of scale per 32 elements in mxFP8 - factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 - return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) + """ + # Reset memory + gc.collect() + torch.cuda.empty_cache() -def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): - tensor = _get_input() + # Context and sync function for CPU offloading if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=len(models) - 1, - model_layers=len(models), + num_layers=len(modules), + model_layers=len(modules) + 1, offload_activations=True, offload_weights=False, ) else: - offload_context = nullcontext() + offload_context = contextlib.nullcontext() sync_function = lambda x: x - for model in models: + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: with te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe ), offload_context: - tensor = model(tensor) + tensor = module(tensor) tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - max_mem_used = torch.cuda.memory_allocated() / (1024**2) - torch.cuda.synchronize() - + # Backward pass tensor.sum().backward() + torch.cuda.synchronize() - return max_mem_used + # Memory usage in MiB + return memory_after_forward - memory_before_forward -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) 
-@pytest.mark.parametrize("model_key", model_types.keys()) -def test_cpu_offload(fp8_recipe, model_key) -> None: - """ - We run three configurations: - (1) No offloading: All activations remain on the GPU between forward and backward passes. - (2) No offloading (one layer): Only the first layer's activations remain on the GPU between - forward and backward passes. - (3) With offloading (all layers): Only the last layer's activations remain on the GPU - between forward and backward passes, while all other layers are offloaded to the CPU. - - We expect the memory consumption of configurations (2) and (3) to be similar, with - the difference being the size of the FP8 cache that is not offloaded to the CPU. - We also expect this memory consumption to be smaller than in scenario (1). - """ - import gc +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" - gc.collect() - - model_cls = model_types[model_key] - models_list = [model_cls() for _ in range(NUM_LAYERS)] - - if model_key in ["multihead_attention", "transformer_layer"]: + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: available_backends, *_ = get_available_attention_backends( model_config["small"], qkv_dtype=torch.bfloat16, @@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: os.environ["NVTE_FLASH_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - without_offloading = _measure_memory_between_forward_and_backward( - models_list, fp8_recipe, False - ) - without_offloading_one_layer = _measure_memory_between_forward_and_backward( - models_list[:1], fp8_recipe, False - ) - with_offloading = 
_measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) + # Warmup + _warmup_model(modules_list, quantization_recipe) - assert with_offloading < without_offloading + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) - # The only difference between the memory consumption of with_offloading - # and without_offloading_one_layer should be the size of the FP8 weights cache, - # which is not offloaded to the CPU. - memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) - assert ( - memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 8e997428f4..99bbc34c45 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -29,7 +29,9 @@ def maybe_dequantize( if is_quantized_tensor(tensor): return tensor.dequantize(dtype=dtype) if dtype is not None and tensor.dtype != dtype: - return tensor.to(dtype) + tensor = tensor.to(dtype) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() return tensor diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 5ef421bc1d..22779b6017 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,6 +11,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, 
mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -110,6 +111,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 8336330558..70c70c54d2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,6 +13,7 @@ import torch from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -964,6 +965,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) ctx.save_for_backward(x_local, w) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index f0f55322c4..30ccf5ebcd 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,6 +9,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from .._common import maybe_autocast_dtype, maybe_dequantize @@ -70,6 +71,8 @@ def op_forward( # Save context for backward if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl 
ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index a340e7d42a..440fee34d1 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -10,10 +10,8 @@ import torch -from ...utils import clear_tensor_data from ... import torch_version -from .._common import maybe_dequantize -from ..op import BasicOperation, OperationContext +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -22,6 +20,9 @@ warmup_jit_l2normalization_all_dtypes, ) from ...tensor import Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize class L2Normalization(BasicOperation): @@ -101,6 +102,8 @@ def op_forward( # Save state for backward pass if requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3d8862e99c..91e6de07d7 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class LayerNorm(BasicOperation): @@ -215,6 
+216,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 42d3fc101b..8c3f029747 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class RMSNorm(BasicOperation): @@ -196,6 +197,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index b87b12f840..02bcfee0ae 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager from 
...tensor import Quantizer +from ..basic import BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasActivation(FusedOperation): @@ -121,6 +118,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index dd59e602f2..15cc081c1d 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) -from transformer_engine.pytorch.tensor import Quantizer +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import AddExtraInput, BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasAdd(FusedOperation): @@ -118,6 +115,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 448f72763a..21190d4fcf 100644 
--- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,14 +10,15 @@ import torch +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..op import ( FusedOperation, FusibleOperation, OperationContext, ) -from ...tensor import Quantizer class ForwardLinearScaleAdd(FusedOperation): @@ -95,6 +96,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 574642794f..a604e57dcd 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,6 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...fp8 import FP8GlobalStateManager from ...module.base import ( @@ -353,6 +354,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer From 258d084237dccef6d862d20eb2fd63c77315cb36 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 16 Sep 2025 11:29:04 -0700 Subject: [PATCH 19/78] Do not use normalization 
forward + amax fusion if cuDNN backend is requested (#2174) * Do not use norm fwd + amax fusion if cudnn backend is requested Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Read environment variable directly to avoid include error Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/normalization/layernorm/ln_api.cpp | 3 ++- .../common/normalization/rmsnorm/rmsnorm_api.cpp | 3 ++- .../pytorch/csrc/extensions/normalization.cpp | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index af19300a96..398c0acbdd 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -66,7 +66,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool gamma_in_weight_dtype = false; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 1aae72e152..82e360ed64 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -52,7 +52,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) &&
z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool training = diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 59bac8fe5a..c63f892cea 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -110,7 +110,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -145,7 +146,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { @@ -290,7 +292,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, 
unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -325,7 +328,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { From c221909dba98182dcac7bd438edad30871639b33 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 17 Sep 2025 06:32:54 +1200 Subject: [PATCH 20/78] Fix unjoined comm stream in UB communicator (#2160) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d90dd3abc1..0874934958 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -612,12 +612,16 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, recv_stream); + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf for (auto stream : {send_stream, recv_stream}) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); - // We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); } + + 
// Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } /*************************************************************************************************** From ba37529c273182c2ef192e7198ceac1ecfa78e20 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 16 Sep 2025 17:10:39 -0700 Subject: [PATCH 21/78] FP8 Output Quantization for GEMM (#2123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Test working as I think it should work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... 
Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Exclude CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt to TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversubscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir
Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for 
sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero 
tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim 
Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed 
NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes 
from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani feat: Add support for multiple quantization modes in the UB communicators (#2043) [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. 
Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for 
dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> fix linting error Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> address review comments Signed-off-by: Varun Thumbe * Update test_multi_process_distributed_grouped_gemm.py change accidentally added while merging Signed-off-by: vthumbe1503 * Update dense.py change accidentally added while merging Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review 
comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address revie comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bug solved: delayed scaling quantization with mxfp8 inputs didnt work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the unit test error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * just to trigger ci Signed-off-by: Varun Thumbe * address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 76 ++++++++++++++++++- .../quantize_transpose_vector_blockwise.cu | 11 ++- 
.../pytorch/csrc/extensions/gemm.cpp | 53 ++++++++++--- transformer_engine/pytorch/csrc/quantizer.cpp | 49 +----------- 4 files changed, 125 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e720673675..a50b3fbca5 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -39,16 +39,21 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends + # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -2607,6 +2612,73 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "input_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +@pytest.mark.parametrize( + "out_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8Quantizer( + torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3 + ), + ], +) +def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer): + # For MXFP8 and CurrentScaling, below unfused quantization should happen + # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output + # Skip invalid configurations + is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance( + out_quantizer, MXFP8Quantizer + ) + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if is_mxfp8_needed and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + outp_type = torch.float32 + quantized_out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=out_quantizer, + bias=None, + use_split_accumulator=False, + ) + 
+ out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=None, + bias=None, + use_split_accumulator=False, + ) + expected_quantized_out = out_quantizer(out) + + # Match results again Pytorch GEMM and allow for quantization tolerance + pytorch_out = torch.matmul( + inp_fp8.dequantize().to(torch.float64), + torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1), + ) + fp8_tols = dict(rtol=0.125, atol=0.0675) + torch.testing.assert_close( + pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols + ) + # Match results between quantization happening inside vs outside general_gemm + torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize()) + + @pytest.mark.parametrize( "shape", [ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 4c82b8c81b..d38bf79963 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -579,14 +579,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor "Input and output_t must have the same shape for columnwise non-transpose case."); } } - - NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + // output may not be defined if rowwise quantization is not needed. + NVTE_CHECK(output.dtype == output_t.dtype, + "output and output_t need to have the same dtype."); + } NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; size_t scale_t_k = scale_inv_t.shape[1]; scale_t_stride_x = columnwise_compact ? 1 : scale_t_k; scale_t_stride_y = columnwise_compact ? 
scale_t_k : 1; } + auto output_dtype = + rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype; const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); @@ -597,7 +602,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, OutputType, + output_dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9ba..485d67055e 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -93,6 +93,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans bool use_split_accumulator, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap, float alpha, std::optional beta) { + using namespace transformer_engine::pytorch::detail; + // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -123,10 +125,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans "into D tensor. Beta has nothing to be applied to."); } + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -139,12 +141,35 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } + // maintain unquantized tensor in case we need unfused quantization support. 
+ TensorWrapper unquantized_D_tensor; + py::object unquantized_out; + // Unfused quantization is needed in the following cases + // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) + // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, + // GEMM Output needs to be in BF16, to allow for unfused quantization) + bool unfused_quantization_needed = !quantizer.is_none(); + if (low_precision) { + // At the moment, only use-case for fused GEMM: + // Delayed scaling quantizer with per-tensor scaling inputs + bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr()); + if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input) + unfused_quantization_needed = false; + } + + if (unfused_quantization_needed) { + NoneQuantizer q{none}; + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + } + TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; + // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -157,7 +182,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? 
bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -210,7 +235,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -218,14 +243,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -234,14 +259,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + 
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -251,15 +276,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -267,7 +292,11 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + if (unfused_quantization_needed) { + // Quantize the output + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + my_quantizer->quantize(unquantized_D_tensor, D_tensor); + } // Pack outputs std::vector out; out.emplace_back(std::move(D)); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c690cd522a..cd7e70fecb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - 
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8Quantizer::create_tensor( @@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - // quantize output and its transpose - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -562,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // Change the rowwise and columnwise_data to the configured dtype. - // May be a switch between E5M2 and E4M3. 
- auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -917,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { From 7042d7ae6daab0624e3bf7412e276d61be8283f6 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 16 Sep 2025 22:30:24 -0700 Subject: [PATCH 22/78] TE Gemma tutorial attempt#2 (#1839) * add tutorial files and other local changes Signed-off-by: Sudhakar Singh * remove extraneous code for easy debu Signed-off-by: Sudhakar Singh * make cuda graphs work with non-paged and paged attention Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * perf imp for kv cache ops Signed-off-by: Sudhakar 
Singh * add code for calibration Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changes to make quantizers work with fp8_calibration Signed-off-by: Sudhakar Singh * avoid reindexing from python side Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename variable from previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use quantizer only if needed Signed-off-by: Sudhakar Singh * functionality of the tutorial tested and perf checked Signed-off-by: Sudhakar Singh * remove files and update headers/licenses Signed-off-by: Sudhakar Singh * update header/license Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tutorial for review Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make weights downloadable on the fly; remove extra print statements Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint and update comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comma back, typo Signed-off-by: Sudhakar Singh * sequence_start_positions should be None for training Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * add paged attention numberes and update requirements.txt file Signed-off-by: Sudhakar Singh * more fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make tutorial work on blackwell Signed-off-by: Sudhakar Singh * remove gemma FT tutorial for now Signed-off-by: Sudhakar Singh * fixing the headings placement and rewording attention -> kv caching Signed-off-by: Sudhakar Singh * fixes from comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the images Signed-off-by: Sudhakar Singh * misc fixes Signed-off-by: Sudhakar Singh * add more comments to te_gemma.py and cleanup utils.py Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more information about the hierarchy of the classes used in the tutorial Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add better cuda graphs picture Signed-off-by: Sudhakar Singh * addd updated cuda graphs pictures Signed-off-by: Sudhakar Singh * add illustrated cuda graphs Signed-off-by: Sudhakar Singh * fix Signed-off-by: Sudhakar Singh * small fixes in documentation Signed-off-by: Sudhakar Singh * add torch.no_grad() to force reduced memory usage Signed-off-by: Sudhakar Singh * some fixes from recent comments Signed-off-by: Sudhakar Singh * more fixes from remaining comments Signed-off-by: Sudhakar Singh * add te_rope_emb to class desc Signed-off-by: Sudhakar Singh * fix tutorial wording; add calibration fix to grouped_linear.py Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> 
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- docs/examples/te_gemma/media/calibration.svg | 620 ++++++++++++ .../te_gemma/media/calibration_1_half.svg | 415 ++++++++ .../te_gemma/media/calibration_2_half.svg | 401 ++++++++ .../te_gemma/media/fp8_model_init.svg | 500 ++++++++++ .../te_gemma/media/fp8_model_init_1_half.svg | 358 +++++++ .../te_gemma/media/fp8_model_init_2_half.svg | 371 +++++++ .../te_gemma/media/generation_animation.gif | Bin 0 -> 135280 bytes docs/examples/te_gemma/media/graphs.svg | 232 +++++ .../media/transformer_cuda_graphed.png | Bin 0 -> 369694 bytes docs/examples/te_gemma/requirements.txt | 4 + docs/examples/te_gemma/te_gemma.py | 703 +++++++++++++ .../te_gemma/te_gemma_loading_weights.py | 189 ++++ .../tutorial_generation_gemma_with_te.ipynb | 941 ++++++++++++++++++ docs/examples/te_gemma/utils.py | 370 +++++++ ...tutorial_accelerate_hf_llama_with_te.ipynb | 2 +- docs/index.rst | 1 + .../pytorch/attention/inference.py | 28 +- .../pytorch/attention/multi_head_attention.py | 24 +- .../pytorch/csrc/extensions/apply_rope.cpp | 3 +- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 17 +- transformer_engine/pytorch/module/linear.py | 2 +- 23 files changed, 5152 insertions(+), 33 deletions(-) create mode 100644 docs/examples/te_gemma/media/calibration.svg create mode 100755 docs/examples/te_gemma/media/calibration_1_half.svg create mode 100644 docs/examples/te_gemma/media/calibration_2_half.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init_1_half.svg create mode 100644 docs/examples/te_gemma/media/fp8_model_init_2_half.svg create mode 100644 docs/examples/te_gemma/media/generation_animation.gif create mode 100644 docs/examples/te_gemma/media/graphs.svg create mode 100644 docs/examples/te_gemma/media/transformer_cuda_graphed.png create mode 100755 
docs/examples/te_gemma/requirements.txt create mode 100755 docs/examples/te_gemma/te_gemma.py create mode 100755 docs/examples/te_gemma/te_gemma_loading_weights.py create mode 100755 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb create mode 100755 docs/examples/te_gemma/utils.py diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg new file mode 100644 index 0000000000..16e1a43141 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration.svg @@ -0,0 +1,620 @@ + + + + + + + + + + + FP8 with initial scaling factors + + + High + precision + weight + + Initial + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + + + Calibration + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + FP8 with calibrated scaling factors + + + High + precision + weight + + Calibrated + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg new file mode 100755 index 0000000000..478604d415 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_1_half.svg @@ -0,0 +1,415 @@ + + + + + + + + + + + + + High + precision + weight + + Initial + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + + FP8 with initial scaling factors + Calibration + + diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg new file mode 100644 index 0000000000..439f4c16fb --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_2_half.svg @@ -0,0 +1,401 @@ + + + + + + + + + + + 
+ Calibration + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + FP8 with calibrated scaling factors + + + High + precision + weight + + Calibrated + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg new file mode 100644 index 0000000000..57af23dc31 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init.svg @@ -0,0 +1,500 @@ + + + + + + + + + + FP32/BF16 + + FP8 + FP8 with fp8_model_init() + + + FP8 + weight + + FP8 + GEMM + + + + + High + precision + weight + + High + precision + input + + High + precision + GEMM + + + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + + FP8 + Input + + + + + High + precision + input + + diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg new file mode 100644 index 0000000000..d86751e071 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg @@ -0,0 +1,358 @@ + + + + + + + + + + + FP32/BF16 + + + + High + precision + weight + + High + precision + input + + High + precision + GEMM + + + FP8 + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg new file mode 100644 index 0000000000..c3e4146bad --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg @@ -0,0 +1,371 @@ + + + + + + + + + + + FP8 + FP8 with fp8_model_init() + + + FP8 + weight + + FP8 + GEMM + + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + + FP8 + Input + + + + + High 
+ precision + input + diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif new file mode 100644 index 0000000000000000000000000000000000000000..25150cb9b64162084b017442a3905c57127c6713 GIT binary patch literal 135280 zcmdSfRZtvG_%7(d!T^K=5F}EkGa;EQ1g3?#|%u?(VLGySwl2e`@#C z*{i+YzUum_tE;*%zUq4Wk&%<;<2TOuz=!An08qZ8DNAXnONgmTak8-^zyba%Jt88F z0Ym{T|IGpavn=ra{r&CjEjT#%zd}z>Phnx9m6a7YH#Zg*mW+(d+S*!OT^%(wwY$6f z$;pY9me%CtWKvQR5-m3lrzj!~*ZJiYCMITNWF!Rz#rgSpMMZ^!gM+H7>cqsv#>R%I zsAy(p=D&acjEs!f+1WcfIsotph*a#boxSbdy{oIMUlj;{ngD2Q?7qIfO-)VH)6-E= zQ4$go&d$z#eSI7p955J6U0ppRBg53xl#-INw6ydS$5-mFN)`nlu)nZwZ14WhLH_?d zkki2cUqAnVz#vF)NN8AiL}XNSOl(|yLSj;KN@`kqMrKxaPHtX)L17WJxTLhKyrQzI zx~8_SzM-+HxuvzOy`!_MyQjCWe_(KEcw}^Jd}4BHdS-TReqnKGd1ZBNeFL_+wY{^u zw|{VWbbNApc7Abrb$xStcmMGC^!)Pr_6|TmC6TSn>hc4AV$vV1%O zCE${^ULCB;9f+im2qBTH&KrtlP%qXWs?HxtWdAvwCRbB1mdf}0XmzNja3WLG50zBD zwrDC>HkRqfa4mGEP&HRRUB0e(u0*?r*LrQZu4JLYs3(L}p}usf#&V|k$4GtIN`u4p zaJoW6`C5zH_0if$L&Zi17y*q;v9WTq2l9!T&0w^#YI`7xS|LNRsd{%LiOXhvw5eu) zB1IKZ2-^wJBI!u&+jCc*EbAA%=>K5WTMgyQ-RQiUH1 zi(*HUP#sYEi)&h$eHA9kF%dqX%@uV=K5)o|$H+NQONb%h|Cpjxe@G{!N|du1VvQ

70h+CDgkjRjPS0?*>3nJ| zqhb<-RNW8+_+e7_pOl0^RZFLj8b^a+!w`9?)Ut!49~8 zp%T5n0$DiOY%!{LTLlA?j;xMVAA{8w$ zKnyg!uyzN6iwxaC;4s+_6O@zJ6zGx^bre{g51rLJC>J$O%10Mj;r=_Vzn$VE&`T1e z3xl|+C%hhwkU73uo>CA(o?s!`Um=~5zNWBxr{B3OwA4YL7Y-><{gZ?*1AUFk*K=^cVsuWke^HgK z#Pw#YJ%=nRQI9q=<8$mEW&b$;-7Vw0^ZQ-7n5A-o{p#)cw4ksd=ZZbWYVZ`qJs}iZ zk3D^RidD|aOL;>k-U@41AEWU@_8#XRe9Fao=gcP)i>eUcgIo8jZD4Qb`OJU1lH~NF zRae$@=1TyQ8bWmB6dNr!-^!tYCmf_%M_2N}bzeUvDgsWE<~V+*mV6+aZ3jWvg|Q_4 zArx?J$h2qf%(WBuzS`N}yBvI-ski)ET_tvJvYHFyPrc?E}i=zjX?gWc`Q-yb4(B58VB7{gG3an8&CG> z3Bzp|UKH6hfp4adLcd;OoidN1HyMO^9{oXIw{b8CzyYVI0g9RLS1s?eMZ9?&J$_TqW%)=x826dl3{7E*(Fwf-fh9B&ZO&=7kI9?2&BooGD@$2jL zdK#+OH}R+m&zlj8v!k2b*gIv_LVrv4+wE^xwKJ)%LXzR~PmQEUPJq*y$hqkxm5L{v z(lea6gw}?sQKZ?j*F$S|Iff||wquJDp(2C(Crv80+7L&cSqhRnX+YSOKry#}56f^j z<(EHD-|suW@92Ot0ddknvKol`NJ2V~W4WP48zrSl@^T?qfXbdDkxNKRJzZg{g$u)# zQ%KH}RYJ{S_Q|6tpY-lK3z~eUQc5)6FsE3bx;hly= z0Bx_`&qk`?5*-OQLY7T*sbE0uY_qw+1sKeS=v z;CNAbimwac*Ujvn|NPZz?6syW#6Mr`qX*eSq9^n&U1N=*O#4avk57z0V9LjZWN|>8 zaRdi1(z8%uXeWj5{g1WMWoUM&34?@7KqxjRX&=3@_txT&OnGJUvw$&$pnxCo&9%p2 zK5Uy%A?_x-AxbCGC~sO+H*zKWXq4oh*zw{#me((8hWnKvNXD_$seOar<}k;0o#J0_)6CKp^W`KtX$FyX znRqW&yXScVGfNLYpWiR~4njMU(H$8hmp2!%D*goNF-xTJj;_Ghd;dTMdsd`uh$P$u zDo5;}*Lwm!1$X>xWEBC1Ux&APNxL<CXF6tj#a*jHo z(mIhQ1`l-|JSU;LbrytKCrccx?44-s6Z78%-K~Xm{2~R;oXs3J$j9HWhZ_!8NfE(Q zd;R?D;hb^AD@#03UUG3d(jCybv)ShFbndTfml-h_fjKB=tK;5D`aSC|I2R%@dW0eZ z5N%|-8@`#LAiv(w$;AOr8~V@9TcMe4ev9pP4HAdo_f_k&4XErC@^ownXckUX`ZA_Q zWNRD8y&QEPUSref`-Ik`%T0~9V;y|AfW==R1BHkd z{pGB%=WmG`It2`fyF#H=HJwsb!!PARb3-KREHjBu*WO_OJA4-(o1H&%0}hP^|LR=L zkOEEfVrkJE1po*>hgecZ^M5wPOe~P7x9R`R%=_S8`TYvgXOn#g@=Dr-^Z2<-s0-6c zq7Veo`u=rs66NXE^ZOv{NAa~K*I8%JSg}!*kk9P0K=ote*Dy^=g8?S}2MPck(0I`e zr^qiw)7-0o?hAvHv#fw)3@^8@Gy4|%We<%K!k-Y#AOcI#zit3t6R=tVAaoJ(L=XH6 z^pg|!@jj6L97Ki85Rkz~f|m6CzL0FVN3;B%l|DXrvBxyX-;)0iAMyoP;;s8%DsgWo z8Pp4@^qtR`Kt1u;pe%apqblB!LRT}HU!sT z-32`(s!-{^cJn1dbeqI>Zua3V#ZI*No9OoV5eDbIKkEp?VSV|@{Wk!`T>alB zeNf93PoR}NBsuX<#ADJGdkDO9_$T!rsadu~8pKH`$!SZ;Cl{g>-}SkD0e)5}nuA8I 
zjwx6N$)!svI}}EF`3RUr%Gm(L-EHD}e)00A)Yc_*W5$HwEG2^>87Uo90nvDPznCt^ zv`MF!_J!0*!t`nW^jV$s`M>Fl$?3~Y>8ngj>FbZ_Fv5&2{)`=+jJ>}Z2gw;nO&KRk z8E20f7lfHt{Fyg8nRkCPALx@apPDjXmNMTSGXX?d@B&$gx>+CGvw$gCsLffRW^_Z)_l944CP9G2xAwx=9+ zqFhdaTyEXmFYdW~DY*j8xkAgiB2T%ZM0w)W0(p|UdD8BAvMG7;&3TH;dCE_DszmwU z1oAa>^EKV`wNvunb({0`m-B;-;Ejk1Oaux{bqmbh3oKI#teXpLmkaEl3LJ6MJdI_&Bdk5#pSwCkLJQeqVycyl6v=&#*~ug z<`U;(5FKVoCsAp)KxwaTX}^2vU`pw5bLr@E>G)IWBvIM4K-sKr*@!^NBw+z=W!dU- z+4@r%jHrA|pnONSe9yi7Af^1Mx%_0g{Oqaxf~evOwHyafkS7mE5CaEC!S-BS@$yvh zMpXGKQ27m@PzkSB3G}E$wW$Q9R-(65Vh~qh3RZp8tHQRa!ttmgO06OZts-lwB44Q@ zW(Im#!!ItU8p*@Kf2*KN1$7ix!(CO=kEC)EQ_y-;z+aWH>I>GCC08@sRIo9ZaOhPC z9jEZ1QScJi3YL%x=+z28mwauh?T6J!FxM$RC&`qM%7ucI^y;*ZlQlwXMMvsXkLwIa zl66u^^*lg7&?xLY>LUp2bp`8f_3G^vYTq3~>%>EA-A3xXMr!}1Hn=i2e7uCG1p~JG zv408H{XuKgWNy4ks_&#=vLOJ?87FUn!}>a_+ld-CRT&@@Oz;;X@o0d0zb5>dw2uz# zg{q`*RgF=^P2p%Dq4?wiWEH55}W zF9(4X1`jXFR;Cd6l-0JF3if|S!?^;S`!pXBYo9%%?2cfdkbqv$K<{W6NMYU|YnbY{ zL#G2YgH_4%ctR#UlAlsPqOW#v$K!W581V`i36|h=EzfkNS9Tu3x;q@&y`LL5FuVTr zK$ej@vQtU?4p}FLjJ0@B=bk^(yt8!g-gUNOvgqY~1xvu=`qCHn@uGLq)2gPGej(NQ z3m|e4-rDTdXH5I>n+K5-p=g(vL$sfez%ho(cZ&+(=?h16=!gtsCAa{5mrW zL&y$46`O!F+xo9GF!V&CbWixNAK)_p#5r1@f!2k7Gukgnv1BG_OV2cU48NMMc6^Z7 zlRNs7&)uf~-y@`F9Xwp{UTP4ShKA@wdhH2hvgML@>e;(QNF1>~=is{~mis2ridan* z?CD8L0Vw(g`E^T{t5rWYb*q-4C1LlTS-onyBCDB={H z<}1xC`uanL@|?Z5<{GyiJXnfCA5~L)zmwvfp#A#++J6NFRJVq9(vMs*)Q=d?CJG^p zuj9*5mV% z**0u1ODu1E*@VAi1`teY3-v|G@nE4t!(e#N*ur<-D^Oj)`_Uo$cLwfFpFw<^&j844U^17#p%R~bQ?M_0yWRe( zA#NaO-pgeFo3a#dxPKDX_q;RG#_ljSt&Oo3XW<`0+7S@JNfo=OdkkFa-sPthZ~A^1fLh-b{&3v@>|WnnhLcjXFcHv>HwcMDPR~LxzI!nY zHg-9p`uIM+X|xWU!-`{D*f~+X*z%?wd=n@%80W#=|DX>f!dxQ@YWehQ|8zaZJJ6gi zb6eQp_-W79l%VUh;3_wP!TB+fGID?Giot7bI^qnC^cOnTi^z{Q#~CFdU(is9!#r)m zE|wxCwi3mP9i8TXSCmHIXtTADbxCYUYqWuqQ!W1oLGNIE@abpr`DN6}xaL`e@l))N zA$;jPT8B5gb8MCK<=fp4HKTKlo*vvrtH1)@*04if=eF3VxhM@|;VSL7+7RnIZm4@;SV_-TK$1Nd(1)hrc}Yr47o+uN+@og^rytVdqY zV0H?CsJ0~oJNu-uEQGI9mLS&hA&MaiE-TBAz4KA~;Gq&2g%wTk)6IG@nh%VELd-0v 
zpr>GL@~xz0WW~nP>6lqaQ6I$oIOGCt9{K!*;87Ngv< za+-x{w3?jy)f{rjMPiRpuA`3Vl|96%$AW{M>bmP{ee#=>y~bE5T43UEL3 z-%(G<^|*;p>VD*$QjHIDkQBZT!5kK!IWoItV5ZRiNitwA`Ju#*p0u(w4d5A$>+|%$6SEw4xR$mygpf`xWdS@t(S+6}zC(&f!Gj@C)5~s}v zk_`JVPD4dF)$WN=t}2}vv^&1g3$y6Zxr(+lL#FOcGC_3m2sPOX?ho3pl!7? z+p%r!v~nn8?d;Q_jrqL6$aejz_g^&oMvG!5`}lSHS=%PVj1v2nv-TG1mWzZWyTJBw zwEWhGvyopNSTV;HkxwFs4xMZ3HV)mWpIbLOL4w-nW*A>bVZAtpVUGPk+)|DZJmeJn z0d8*=ry;sb5XumP3f^TTup(_|gndHs_gIUA?e74lUnEZBjKe2^zbD0WTdya{>PBf@ zrCqgerj@_iF^{WKFFK2TQwgV?)c!T*()iu=6mvSx3;m0PX=3=F<%zq}yFy)sKQ1fL z>8!46e{AIL*Ipq%c2?XA5AQd8KfYP4`gagMq}YYf@DBJ}%&=MVqhAoaAh#O{2x?q@{t3xiznVg^B&FULZLK z0yx!SkI8~{|C`#@PF&l6IcL2e-N=ndPt)bcdZ18p3wjjFS~~*Q*nnSPCe;4;6lpsVCRdn14=xRMaA> zNvJW+FQ|1RWU&rZNzTjSt>_lfNjgv6v6bTM(`k1~xK0b*DOiz)4u@uOX58Ea@VL0B z(D6(fT+n=oW`im_Rdbmgf{J*N2JYirCSDKj2izT`NMq z*ofshvfa4s9$QFIN(xLD#l?j%BJE3(U#kr`JBM=ko)vq9$bGz z3w)Il3$h*$R6Z(_B^6cseDf{*M7_{7nCbg_g)Szd1AGZ@&C#Lv5n6k8EtFZRP&_pS zQk{ZZW_5dHVzQi+BCT4kUv^xK*$Eb=izyVSVe0e`nJ;t?g398a7)=mq=LrjaXqJtn~zFH3j`XD;`~;UZyKk?B8Y8LvLA&BfZ69ZOm5q};GsgBocRLWzMTKM%2S2Bs=8ZLKj?olR ziNKM&-jT)aWT*BpNu98$`oRP=N2h|?z3@@nViXO*YrCj^1nJdp~6n?tL}2fk^ecIOP%wG|rjFZ0j>NEVlgWz%l+ zZ0k5ynnmqiMK$k?*&Cf_=-8yp#8aDv0cM=v=?46d`yF>nWFs)<<=-P(Il-FTKqTKE zWBc};Efa$US*0k>Hghg3V=Ec;6W(5oZXZ!*KmNG4Zc;oxPDss@{&1g38ba%y*t~XD zpWMpqa)FDLmUZ`~H}bX3%1c6M@OAtiS4X&B_ovGyQl?Hf_3h;+F=7XFa?bYTGFv5e z{iU0=1_hmn4RFuu8p-*?uy{K|T4K?j3mZRWyWM47L~6doKhXHF5sWMN^*Ys6jozUDXfd+PYlPlsw#56k^jF)l`gvYkA>=3$)$1q9gvqL2;-Vh@^+A`O zd&$(>q!sULk3QJbH|y^&xcdA++=%en8DEq@Eh!KE4-y43b=y@;#<@q=$&m75OrY zxjF93V0~=l<^=k@!T%fEQ{f|^p%|p}+9$vT_={@GNig_JJ;=&v%ZUwouv&^R4*I{| z`BM=vd^`9k;-=57?YHwn9ps`3Qw8(C7Kp`d@Xl5B`VHtQ~473~!w37|AZVa?% z2(kbM$(wt4?|W+}1w1(U1B(vuJRE{^1(%rGx&c&K=Xt~v=MfKO3{$LY+#*9K$#iDo6jL-F<2GhZw$VPcb#{P zY6uSp;6Cb;#3oj_mShftP7;A%<)CXnL1ORj8xxfR(8K`pE#k`3{#lk3 zG?)|y{lqfb7+MPsMVi75`2w1krA$lW>TSeYWe_`B41b5gyL{6MSSwCm=rw?xQLZ{Fd9k_P-jlNP>cJjDv{F zB(5-RieJ6VWMiIT>i1({#jNS-phOmJ8I0`wg%!3ePv6x^A32Sb^92`M;FUDmh=o@|9yRi_B 
z5Ah1-$3i;XOI4I_7P#J1!jx6=gp4#Lmd0u*--zYhRT$%)sAImn#fO+mE`fF;d8QyI}-tk(3v=22Hh zNdPAtVMkIZKYB$|RCF?tyK!?iY%J@~{YZlG6B~aQJC~yIot@#L{NWBMVG!Np`4guYb zF&u_aTo%zls&-w2lxh#2=J}avI02bAnbNL7oM!1!x#=rbECnb@MaikuoAD1E;js@1 zQ`D)sl^J1CDImsA<&FsmqH*w@7zdP0Z7AwbLjhkca)uLfSQ>HwS}u!no=72C??f}{ zFD~)#ts%uVI}}ACj-poUgZ>&$9ei4kz3=OP(XFKd-^1AcQ_bPX~j{S}xUBuI zg>AYuT5O1JqplY65wzo*S@81DyN7h)zW@#Zc7Xsk2L&MB0dT1UwBC%}iHVB6>?{L; zV{R)!RztBQ!=I{t1NdNJZAK+g_0Z$*ku3M&AsQ7ox&?N)6^PcQ6g3s{3l&4CJUm1{ z>cF9FmFNdX_!|xYzscN^3v>W9$}jq3FhvziC(qxyNZ&Drfw9aKnwrYU-c6+0i;Ah~ zY9mtjl`fT&A+@3%w-P(L08yga-LDkhh8guJU2Ca4a_ZB~ptOvY3RS8>1d-PGlb)={ z5BU*~+PjLdsptaN8otz@k}E&!2eXn131ugbYOuBJjP39(99BN}~ z!yBj9RN&Nrv_eyXRKINA2#46+$&g`(AWw`9-bhm8yFA-oXhr9+>axuj9)-Ks(da}O4ds#kKXJLrTTz0@7!1mU6v+-_cXIo$%I zZC9!9VhwFacf{8pMDs)ts#uZjMqXjz3@!FNQ0FhQ5MezL`5fkZgWpW;(QVAtW^r7? z+|sj!*NKVHa%A4ka3bnO5WpKo=?sNWvO+`8G)c;EqTaxFpE?vz#R#ou-|Bgi)z z&i^Ib^Uy?@&)oJtVrM4a8u45&7UnFG)(6e)wZtAk#qKD->RO)by8kl(AnEQ;rBr0; z`0&y{%Q;~BZBT=HP)M=UdZj0#A}Qi&bpcm+PL?VMz6G^M}DdFN0(xqa9HL zOq#9vFF6UVBYv@JVb;u8fH4zW$~oeZ9=$O)n>GzlZ~bb(s-k*h>*%l5k%5(=Z;BNe zp4EG!19>kl>z*zho~1C7aY>=^Am0A+RgnkJ{_588c)HGb&*5t!J~yf4hHn$t2ovAF zkM_0>WQUD4%ji<+w_l`A64Ot>r`O=O8WXJ<hJ`w*1+9a{NF}Df!tGV) zJRz^`kkjyb)|KSy(r)as&r3xWV<9+7tFF+MHsOidv5K>x74@Xm55p_N9(IflM0(L) z)UjjV)mmm|l-6RpS7ou+Rgc%CuUAox7jR8nybm#P$X;*QPi%485H8V+M6p zQ3C-oiC`+10lVZHxnS`OgJ z|BZjobll&7^tYz>?r`J0ajVuN;hy=ruT#0NiTB=MDU`N-PlPL!nPiBKXUlwZ_Xz)> z7iNK#ICcUWviGnkLmOvNo@LA6VsGF7LAHrmW<+T~>E|tvk`esx&FOXY2jl+1vBxOC z-LSR6p;%2oR{P%o47@ahuu3*ikg}s)Nk7K^L8I~r2>wu$`tXay;d_ZUjTfKWr0XGP z&RW&+QL4q!fUY_kMfCN(n-$TiJa6i5#{me!|6=jad5;;XvT5#`ssSa z9%vA@L>99`cAAuebrpftRkm&1c=p6}lpJyR_I8e5aOiJ;`ZwtG!X9?Fe(-$O+>#a9 zJ%)xF$uy&N2KWBUf9ycI?woz)Y#IB6QREC5>43dqeW(oDf+oYNPl{_U_dj^ITUq&ScxwjBWqr+giWudUkjYG zd?qhdA;#s6hWp2i%Qs0}_yv0oF6T!F`6YjJ2JH5esI>523qD8*s_9R zOf0YHCnpOL5KzgkkVmCx6ys^sh%6#A`ZV^~LV6NEFSeVJ-m1X{f`1O`lS?EECEMmn z+f&Q63&lJBFL_%~0{!3e7RNx!hFJ&o_usOfjv!QgI-wi6AS&6QG%`$D4t1>%mC)f4 
zI=tM^`OQgqf}!)8!!foiZ1GKujLQvXo71%NU;YQ)c3p!0Z+IIj%rwTP&^kq+NmV3= zjZusM&?W`rs%cB0>YKXpv0_n3I$}9Xs(RX2ZW%kR74ISn-mR+E*-6aD3ErJ`a=bR1 zE44mMZuZAA8;-xX^4s)Jkz{)QVJ-Dl$z`TwF=gih1GPS)Z+SC(u$VSb5ky{YXuLUW z)##5`XghnUzdhxeA$@1R$~rBhhk=h?WMDo(x+0kGB)bCu;eNGn_E(i>jwmf3=t7hh zCI(reB;<@*Gsh>|nguL*ndw@A+D(m&&jTbkDLMXdQv-PRep5$aEA#0YckSMqgi7Mc zZAVJeL3d&l*^;R_1q-e^0ifPVFeM>HZzlpDg}#%^Ul#(*Yhgn^j?VAn?O3J|# z-JGYN)tHv-)AmDd*;*^tL#Oke3i3Q;w;a7$|+S^ihkR2~x*4Jy=Rt`l~UtJ8Nl8;`Eu)h!3PW2Tj>MOL$O_%O;OKuhy z4vC! zSKH>YZ^s`QZCL)e$MRcvZ1dTCys3$a7`Z;nM)=LKukdGTV?#VKTXSQ6r2TVr%|`k2 zS%kaD@ZrDfH}y?XB=jbR{tA)WBPs-R_k&<1d;62_5k=42MT@T#)FnT-MwnnhpWhy~ zkX1~nJNijaM0Og^-&pKQUf*7JL(7KlZ}v04WIjqB|9XFFN)3Lhy{|pSx&Q-(a9#qt zIzI?p_{{1>o$H!*ii=(N5|9KA;U#rpSWh~Vfi?#v54r$vlJ1`Wp_EUWUAWCV7Xc{4 zA5gd~#gUYv#1Ifc$QHmZ^s7q0(q~HaM9N+=d4&Mf5>in9d^fGYW$+iZURaiz50ma? zi1z3fE|)4Cn>mGG+w*8n3gi=O%4L`?A~pTwpfq=}+9$iPpLiQU8UE#}&?#nunJA!? zNpw{t7(^pzKgci6#1Q1mvOvDzGbl-96CDwjM{`Wsk3hs3=$W>Yt?HJhXk9HA^^(UR zLl~l9Fdm0&O~W{FFjU3h7{7x`%VNPOudz3lfDl8=n%X0mec+Th@<_`bv?Qayj;>=* zw9OpcRIJ%A5kJZj!=5WQI)+G8lagLa`Q_uR!jfH$+RRGZ#Y<%DhVm6Q0(f%r{VxDEasCJ<;D;XNQG?7$K?xjJ4UAF1m-1sG~`5b~{GEsc=% z)ElfP3Q+p2RPLTc+c^LHt8!;SAZ1<>V!9(GptTj`ELj@I23Bz*u%^SFcnmZs0EJ4d z)>E`79-2c1$fmp0DzsVjoxf`y-VfVwi&*>?!=vSW4{txv?LlA5+ARZ zciW)0LRIcz%$D>#sgz6R{^t=gt)_Rx3b(1g)l22v#;$4VFCl0mc~SLk;l`g`^>Suo zQIOmAcIln2M>ck0S#XJev(Y|))?Sx)+=-!|5viuGjtgM=P|6s+UBvON!&| zvN+uMW5R|!;WPSmh^FD@dP6Tho|d!v^{C0UIbZ9OK06Bu%eK}DSIc0DBQ2ReH^KzM z0MnggC{oQ9CGjIYGNen>;9;x6?&0%AIEIHySoHn3CUYvyw5Uv-op7PY!JmhmNdrQ$ zbb_O!0L2Jjuf}Pw@+Q5PB~Br*rx8aiq4l5knv889^T1o3Aj5(S_hVv{d_1CYZ}H2U z9Yu`Jv%e&1UbT5smlpotpISKVuj7B#S*fo!&djcHM@Hire^SwQMmx!&PujrBjUJGo7G@H^0WRXorN2+0|j4~R}O`GcYs8cC_ULUje4*omBVJX0xZWvz9FE<3Cl=nZA*Nit1pJt1?_C7`(R_JFCM@S z)DPQ}lvmkk!B;vnu+G`1mQUZ8gaxeiptP zQ|;BKt}I87-%4lWh+kSAmEGKk+pn4k1dic4AN3Rwuj-B5*)lTFX4KcPn>yXubpM4; zX^{hrJNg!XXFup=wOux7w=Cp-{;NModU2QYSB~A_-ej0rS7?cI8eA z!u}tu1R_CzKnN{Lkuf>HT=u-$y{7s2S-D8xW%_y;g%BN!#ghl;?L 
zjKGoe_YbYr*x#e2c<@6`fLk#gZl=`=$%Rgj#&x^Z%fCoQbUf0t*S1tm>z(9M* z9n-~tnvmN^aiG7=l)oxp057uNPNhc>*n_RmL%_|Wjou-KDqss7a1HY-G!N8B3bc6e zZ=Vft68|Ud=J?muuciLqS-rnCgx(}6$kQ!Q6KKT=aph)!=p9HHLHIm?7^Vc~!34IJ zi$PJArZ_U1zl-p1ErK13Twi>G@Ma-Z?7 zZHOEnb3a3PL4R1sbPCCG}S-2M{OS6elqhCu1d*ofW6VA1^}~r%H&b?=0gK7jfAT z@pmys$eb`?4>5riFFGkcCPq5Jq$y!;Cl*kYQ0x$67@T0MlX!@jh;N=SPM6?jmFV3h zZ@=`<%L?eD7-=(C-b`|J3#|;tWwI8h09P3vW8^y{#sN$rZhqYi>#cA593R9Qkw|V6r59s zcv6wz(|$^)bvLEC)uhrcqfjfc|L&2%z8 zOJt-3r2jKZj~UF^d(8N|okm5UArzc(@i+6xB-1ZHgQF;Y=rH4fFiR0JBO91?FOj8Q zlnLLQg+!O|vSj%o#RPRY3#VAIoiGFIDG*;b>rE${teNvuisnaD2&ua%FPeL~@0J{DnrtIe(anLQqLsDBuJgq6~4N z20TzzDX6XsR3{#46aY;vhQ{bZvxcFWOvMRD(EMd+kw9?)QE`cUacN3%VR3P#F7%cU z4%?{2^9g;n5<%qR``6q5FL*0Z)`nUXno=g~E+%jr^MAtIug)gKa&UA1C%hG?m|ZT4 z%`ba6;y$!4e_5`0d#V5sS8g5^K2cVF@Tdf)R-!&vAf{GyVLao1y+EL6uOg(V!u6>7 z3jSa4wxx<}rHbOYii)_JMzESrubRQ5nkluKrKOr}rJDV@nv=MOTd?K}S`EEw6+dy+ zR~WuduWmDLZ4qrPu4t{aUadrEtyD>^3>rZ0xK;tJPU*Q;gShUyV4Y@aopwu|;!54O z<2o(0dR^vv{pUJ!;(80gdMmwp%awZT<9gfYdbO5XO`Y109u01(4el)so+}OB&kbPW zMqj~3f4#;)k48vpV@OM5*h*srMG=F*ks^5^DC;+AT`mRh}*dXJXI)RyL!me!S)_UD#P;?{0D!PZ{A z)_#xH!PM5_me$dg*74`oN#eF?!M0hwwt0{L3~k$TOB>bdO56H#8;rPpOR#-MuYJ#> z{UEjdsHOd6rTy%={ermTN<^^ZMz7<}qvIj9>^0(>UFPp$6D>etFJ&K?xqszrqS;XI_hFrRehfH zWm)ZJd+BB;>ERUW;nwf@;@QKO)+5l`BedEh^3o$p(#!c0#i-xQ#)BrF)+^uItGL>$ z{L-sR(kDlvFRkCl$piY9)~DOrr@z`~@Y2Vd1}|ybXBwtjg9(67>SyNju^#QWS?zat z>32HmcP1Hd5gPcDHsIDe;JP~Celp#n1 z5pAPb0S(i!&akgd%X-bq+s!Jb&nl6QBRYlD-&rIv>%th^{ngcr)SqV=0#l%^Q6fp>ZK?ZRu~@lD_0pg79(z*X(Eg z`N*}!sJ7+ewPon(B8%;!<;z@A-j5XzNjNspaz@)qYul8q=VCVLbiMHExbd>K-6G_5 zxwLGxylr)SZB-S0;bOJ2_s1HiH9V{BMDyBzVbR(IXt|wsbzXQqH-5^Sb?M-Bb+m2$ zXl;FRY@I@BZAW;+WlO;Q_idXmYnxI)m*vxAmfzO@a39ZzSticH5%m5mMoE>;^lm_B)@wcVML}@MJs65geLh+uuZX zRYAi_Z#&xMIfQEq_dkXu%6H_*_GCr&STk0@>ARL^GHTv1p!b}{+nxhij~V)KZTj9H zHV#_!;Rml}Bjr7l^?mRjOs>=33gErs4Q*aq`C z+fT_jOe0%bTJv>bJIwLUx;^Q>D_si9I8rZP;dTHI-IEl^yG zySr1|o#5{765QS0-QC??Ki>EI&zd=ygE`50)=G9(o_+tWy@-6_$=ma7{PZYeKbp2@ 
zi}(CBVqq#`ILqVwf%lT4?!2x1Y>4jSqW$tpYyVbyxI_999pyX7Y8bYn|1IN+yG(tl zeGQH8ir7o9*Xr1A{1Rc~lCokwt8J~q;)29_r^o7QGGhIk_6_0K^@=70{l_(V`&g#u z6}`ImiHz^4@LFo8ngQ!8HY1kx2`DnZq}39^dtWC556Df zav4h&GB;M4k5jD|RRoWIACJ36)0P=$Pw5ZtUQc?FPg7<6rL8wne5+6GBO!dxvte-N z)^`yfA5SHb2m56ged`aE6;IU{V~PyV80W{}i?e9=k-UwUq|xgi9}AcnFTGx2g_$dL znNPz+FC!fzSssh^p0^p9W9QF&DY9>5fp2MP7p-2W?H_L%?u)TRD|zmBeHREw3T5+N zZarD=WC`zguHAPnV*`?`tdy%yrZIZ1#}^8KuJZ7gNuicaJW<)VydPt z!O6u4x|WZh5%5OL;Hb>qR>WUGBY0pi0zU0$eQV(!mR6c( zF^y0Bn1;qfr()?5o?g~Aws!UgB&6q+*0S=9$S9AHvP}O=qRG@Escr2QnOV^>v2kI1 zAqMIW!K8-!Mhm*k67Vl0g~3l^-0lvh{wClm%9Pt1Qu^-l3dfu`82uxhGx$xMMzbRN!M#d7(~5Eh{cq<>fhI>go5LRi2b z^JN+>T4pSarHfSt1LNSn&%;M`G~?&uBdk^`4O%D5ZJ}%x^8(}v;~KE+m787uhE#^g?l@~PRVt?#&fRpqH%c+_Z7jhm_h52P+*qiMbLnKP#q9(69Yb*?cqElC zlDEyZm$?kaMuVmG_38FxrJ+}i>u3NQ{%B*I$oTz@0?X-v;IqChu2S+bU$4c+vqpbs z)wRLUGSLmhv;vV>SDDsKD?m{Hx4w18R)oe{3=<3}WmgWxZiCZTXZFhZPk!rbKW(w( z-l!ENfGlPZeORjubm#hSe*5M7ve&>ErEIvo-52aIW!@pfL@i&t7{3i65HR_hBWOR_ z%&yTW#iBjg2qu&#={JnJXRvXGD{HY{xOq>dy{ca$oUYq%{;)=ATLH* z+*HrtT%ph>)N1&{v~Uo&c{@GKTRf{MV7h1zluXe?AXV`t97~>UN0}GQuUZ&0scnnc z8o*U?ToXw4levsI;Ys>RHzZ|HP@|vmE%vMZj3{4Sc!7X+_+18z@&MDgujAOiaXk4n zhb1FU%OE#{qHIFA;iL_X!9LAS` zL#vcPhuyz2gbsT#umOy=&)o+1^Bx~deBIn9JklhV)BKsKuJ#I>`SIL+}?l*hX?+fg;>e5ep^U9GtcgvS6)?8^7 z=kAZ&*7>a(dEfuG8x;k;-97Jg$eflp1z1w`w(SA-nD4l@qP#a#T%qzIY z@^J!G8;`G9Q0g!|b@AWb7G|_y7({=wzL}VP{tOdiJa(dnRN@?kM!ZZe)P$yAqLRIi zB8b2VlDE+dq3>#UG!}N?A4jRP?&^GbZixBlvjGVy{6UT8FbQ7hS-M1m~1<0nF!Zp$pve;k>GCNWDxd6d^@Yu?r;PR za0?FyT0+UixWf+ce|-6?H7Q04g*YGtO)nP}Rw~oZ4F{&vvyF)1j{A^ell>r@;8b3mIV`HTWWsYqY!oR}HDmR*7eXJJ#esP)fLWmQ8lBDD( zXqI&ozf~E;SYk~YYBLe|mpgH4Ql`;Vok}T>>YY|3HZdO>w@->rtym19o-%EhVPEhR zwWH@b*W+|-&syjge1OvTdAxhtzUXHUg`_ppWG-oZws<^@T;_AMi10+AFoRaBurT`Bg93Y((=g1?7lss87?E+UPMh-xL_Dl3^AHFBuY2pD6=#4FbjcC5{# zD>F@&gL4_DO9-^2lX)>wj8YP$zCiOgkqKQ{J$7tJ@h0THJ)W8Cj;TpGE42+B8C^tR zb-h5PvCc$XI~2HgsGU=`jZr6GFmq}-_FuEe{w;G9;?%mkXAbR>sJUlX*?3A;?#2}& zh~I3--eG6uauK@m%KXo)lc2=`Y_ANAU>pR%L1kJ^;kw&RhmtVgJ|23Bh4T|q#c|Qc60f-1DQR%MTC>b 
zZmf1<hsKr9jAs2 zUaF7^^gn+_;YLK2I1TU^)uhtF6=w>O4SqrfyeT1wTqn5Vu_iP=crW*mz4P=f~R1(aJ`NZtEk@%b`bY z#vJzrB}?aQ!z!uD);n&S@q>}^gHkp;d};(|N)0_r;Up*9&=uFxkDX0R32y3X!=~R6 zgx`SI4Zm?JAU7%+hgQhWWl3Il<;$zs(Gjnb*fjR*eK=Zo)NB*H@OGx>o1zAGuDSx< z57()y^tRRPGObrPcM^)H4PU9VAf5LWyqxE7p_Ws@EC(_VEE6p2cc?^542%@bOH?o% zloc#Tx&v%8is!fWUhIpjlvi^XR(GuoPTfCMX%_aREp3GHPQ8s8y;AM&doz*L(E`)e zmb4zz3tJUvBHjH8MNKBA-EW)^T15#i{*-0jT_k+)-7f5CJ-^;lo>VHeo=^qk<>q2u z7e74j<|{lcgWGZGU6{9dP+rF?TBIE^8w=noUW7Awp9BfHOcGz84|6HjmPKjKlC94L zw+ZhS6Rn_42c5P~G+&uSsw5*O-p*+>E;Ejrp4<4g_l!KXfjFH2jZPr;Vh6-R<>!L? z;k9`Q35g2zfepnz2{E0B^UMZ#RYpcS^xn0RpYsRS9C%(7+B7uKHwF1j?%AQ|Vaz_z z+&OtA-Jv5<0)9bw)0lXXiu%&B17NcNU2L9`Re<|MUs9Z}Z0rCI6Fk>GE;SKX0(Sio z5r0BQgYPU}JC{h1I$n>kq^OHtOoxz)CS(E1hd`uFFBD>cA`S$*H{jkFLbk};k0b!4 z(JOiv+YH-q(8>RUuF&VA+s8UdU5>)b^qS^5lZ?nD^ldK|_E#Te>0Pn1in*~78 z@j}W1aEkh+E(DPuVn-i>{!(BFiUtuI`6m}?gW1?5RFI`rd{=A&=lDUNQW+XQ=x35< z$fBPiB@P6@XOfMtrYKM>0l3f`+;s04m+0AoqlyUY`CG&h_p=Sa0eTGbdEtQ2Dd2S3 z0f8#a1f5|_zrEkFp{Q`a`r$x9WC4JTASiDDBr1>r5O~i9iKasivlzaC6S1Qbalj5Z zD*A;m2?+y=V84Xq>hg9wgk%DQqv(JjP$7smz1TZLjIJSpSpoMCfMd|t&$<^l1c1x} zM4KpnY%E4WT}DOuM$d|b{JobuCO{1YPzna20m2s;L19Y(Ex2$z;s`XI+~JP9I~18`ypL>dfFIf5+5fl_+}aLNXhgP;t$y^&BMKEe+n z*(*bcs#ETRpp;a+CJ+5^`D0s>nFjS zMaYQQkavo#kUb|FR?M&eF!@Fp%I|FOrx7bSDb)IV3gjBwsmjbW0{ z{=`WTUpN5lj|(t}!#557t8s9W5)0(^BZ)2uU@PoPARY^|DTu=ePz#9`GqYd8RdW=}`#>Vk zt0c~d9mVI;Qq#PF7zZffe&7`NcE%W%q%8q_CQx&yl5}An zA@~eK$@fL9#RMAyivK84(4b*L*pwfoKpr;#hfi>rCf*rVM-d(W&Mx-c<9$ zucWM`xtiY%(94)(hvlT+8j;gw$W?nbjC%`%IF;k|S#UmCi#M+N`!~~woI#AzSaca3S3Ng*wrPM$LA78>Z5U()& z*eU@|*yg0U;)QqbsmO~JIRjkbk5>}rR^g)8>ggU9$6;NZ> zsvQcB+^zvkMm@sjPT_sEbgAN^p&2cy)pg7ME6Qbv*SY~{|0+Z{bcw291=n&i(q68g zw^3rWgp2Xi%;Sm(ct8FF_!7>q%B(vGnA0hM-+?@bk0Zsaj8eyMHt$3Y?JR(6>PcmE zm#FuK|M3P#MLd#UwDkKHFAKf5!aao%Eh_fQa&-e*OzX?<#N~FH(eC{i|F^o_+QX*x z)Rw+yMNJHVF;xP^eF_;8U~4g%?~iZ!QXAWkIQf@4iIG-Tq_!c6Mmj3EU%_C9I0%NO zI#~S{KRo&^(62BKsT>@ip`}{3U&*O0URmr7VhQcktKA=#r~RdRL5GH__5RKSgA7D^ 
z9q4JA*MIT@l7K%b6RN6mh!Hw<+50oq^B1!kpL1Fwz=J^A0&hxS98TZlW9zJV7lFDj zzd*+XWyE4-4&>$MH>#C~wt0FpgwqDNMqNUC45AoqVSeeaksQ&4w8s-xG|*nSD(m3pUA0IGLPkz{$nW3vy~YRoneQ>hIG8^OxMZ(TTS(3R_x>W5pza zsUeTJfgv}`4x}>KZ%~gK1*8bxRn46UG(*uRBOfF!l`-%`Wg1fen*_jOt=eLEUeb_} zMx#?nTLpYLaT$O|C;_a}9Y5;J2mixJqRlNCIrT^J`{)c$*>`XQ#ufAkE@Ka1ek8GV zU`4(Fx^Wt|GL4Ov2JI3sIS7Dr9vDy`!(OX^dlOVy`rEr1B#P{5-S-7{}i>Kx3YNy4?!%^3X(@^W6w#h|wzx6d_Ie;dfY zsB>#5Ue|GnOG(+5%&xnFlG|Fe_q3ONRgN$lS&u{#+FIb#If96TV{i)d&a~q_mLpO+ z_9_{VLl^5TS2od@4_kCk0P2EcOW1!})cg$dQ`;Ci=y(#>8f3#)t)3P+UhJY0R@dSg zo3cYcSCR!yQhVdlirSWCs#R`{jrusg1YcFnbkH$#I}0>UzLPyy7O&=Kr9CO9SXvopH@4Jgxf&<3yM;_$T~2{yPe^iu;ZCs1e5JN0~urfsXx8i+IIgT357Ro3Mg|$rpO->k8rGu z#UminobToP5a`nHhw$$kj_-k|jPIURH_SFm`>b<3mPnqKY}dsd2={>W9SZB3N*Txx zYpXctia_3$CptXhjQwT%{kD57W8{5<2&{;V!+G@0Q=Ub}Qe8^!jYQ;w4!4b}IaUPE zEu`_)@v_Yih11^+61z^cdzV}mcG4{|9vQ%QpgQl7&8;bF_Q{v76PT%!f7$#^?Hz!3e>*`nxb zH_z?4hP=3*Siklg=%*x7yb9`H?bI6wuDAN>t+$^+zP0>rKSs{j+KmWFDLXhGKRY7a z5Q!9nP3+#M@7ZAexDdN<=YHj-{u#-w=0(jY!>z@~uCs9mZG9JW<&#+6|0OW&cg6i5 z>)jgUI}i0OvEiGl3^6Dv)qQw#7^{Os-mBiR<~}-P3a@*n-hbN$;@;JdcLkh5M7Rv2 z2ZNeAP7IIsVduWq2V3uFQWsBIgtw`ls}<+>p68q2wKuZ(?qcS}d6_PvP(<2P_>q4n z2ITQqXYz0Qy>vmpIO_&&Pv>frVVXK^qHNq8%B~AuT&~hRywx1J@$L?6oRy8|UWVVE zm%sAH-QX+Tw4waLd64L_e>qLOYnFMp(|ug`LP`nghS0>1M^TfEcpIHqawc3~E#FFy zeDkut6$8TIe&7Op0|Fre{epx2!+gU-qa$Ntb3FNfxVKf)a96gk(^qAvTkxu;4c-DR_Tww zdcv{4w5vgJxAS?q@y=@rKRBk(HN~j)#dqE8(4>+Bmpn&9r($$`sN{-|_v`+bdi#XU zRia2W^Ni%0s%SkU9wZEuuB$6b|5R_KTJEc?F;R;ZL6^4MNMBHPEHPXvH$Zl8F&M$G zld6684=I{rt+b+t9w*XH>n{d^Tf@rF((GF37AuD{VC|2I_P4p>@`)DT3N=rf=lruB zSj&yJj}O2X6%;lD2N5xcFOUs*j+=i5h;_H#<`QkU<=__d0{$ovN_GH+MeWdz;}_MJ z_Y`&XVd1gl#f#I(H$xfwDR;t-qOmeVbMyC?&DlxwcOous7bhaQ?nQQ@Mahqdtn}s3 zqio?k8#KtGi+|+%p|i3NN#J9OAp9BojiY571xGHfNwyvjpFEu(yy%a`jG;#&Ol^9Q z?z~@INdF=?r=2PD)`So>cgRMW?QmZF8Q-o$C-}%*?~%FK3Or3E6Y8d{*}$eOlBAKD zl^i+#OIR)lu&T$T4vNCs|B*)*GJjZFJ8wqg$~yl+u^Ux}*#uS)TBoF`@c+VbS~f@y 
zK9!yM1uvUVr3OnKKTl$!oRn$iQtTJ4RI-;cp_z!CVkZ4Vw~;czZMD!Q#iKDN7vQOrYxd`sKtu2okrV4H`R)f=yq_ET$We_e)0oS+k3-7`odsN$}$D#MQvn z>XWUegnkz?oh`?qmEfvxrPsr>4X@(mPj6h>mCb~R78Tvq`xRsK52rU?lH$ErB5P(3 z(;nqTSz&8mHx5S^_}~Yp4u)1xutJunR>#VB`VV_C~c-w^V;^u zk^e?B+k2|!rq$$BqTG_-_wf1L_ zDAm)=UkRI%JpFdnXVT2+%XhpyXx#UHAku@DYK<%2?#DOP6*_EF+n?9+-2^>PRtqaU zx#74rxHkL9+PQlF22q(gqaezZH{W8bM89u%tYevn=Ywf0@@sm8Ii1xCn?<^D6~ZtoNP?a za9iKd?7q!7R(=`}i?A6keGM#;MkrwfQ(v9WX-#xsQt{FayBBxe9*dV(pen-KLT6Ig zxj!ka^!rs0mhQcB>jByQ~V zX4#Ae1?cctBR!%$`KBhOlKv`{%HzHt-)eaZg+$OhOxYCYEq|3^9Q>v*tGr*bi}C)c z_{&9Jl9(FZ_LNi~3p=)MfFsC4ZS^ld820aw?z9~xq@k1250Q}N+X6Ts1z}&Il)|+_ z;mT|l;YiR-4k;oV;LpLE??ZdwWu^b|YtGd=#)JqxQZ9`{9>UqsxUkl>o!AN~+0hT7 z+?2agH8~gldstkfH)6 zlWU)i>8~vA(F7GuMSi#DOAX~_i?sfm&+FM)ynd+5@u~%nU%JA``R;e!ZMq*D*&YGt;tiSN%|5?#AFgCZU}rrk@g?Z{U7kC8oaV9nac>LQ+>B^vB;6oGstDyUI6URYfm#k@Te{w*Q)y~u zqCGM|x&$>us(FdD;pzEf@|+2oNggy(<;}J6oG1Yl}O@>p(XF^JvCh!0BX^%|A&zVC|PQ5i}hkLM5d*(W4h?3MG5RRy#E%2ggaR%%Jf39ZvB3<&O_j+{>? zRWve9P2Jr?zO2C?OJay`5+fJ1250BbRsc2UmAPJ)CL`+6QVrKc@IcAGXhU&BdH9Ry zXL3i|e$G}tH~hW{l-e)9(}l?MvY@FTXDLp*!NUrcRbVuQ7>T zQdQ?elyIb{%J|^nrnj@Kf%0)`qmaG|QFm98WEQQP3O;AiazR&MWjA6Fwy+Ut2oQG6 zuydfxe)U9>CnI1p;BV8|5PjWdpJ*C)I?uT0A>a6-wJX~dYHY@e@H_+TU%iNMI@Q+a zIE-CmNeo@F7d778=bmVuZy0wdLtU7O*Q+t0GFAh>FGSNpTo(QINV3qjUnwLfY_ePn z9$l`X9k_2Gsa*fKYkzsW@JU73%@z%vr<=Dt81 zN{onmTdE@M&KilIL7qTkj!p%S2w0`)I6KEf12std)p-M5WtS%?U)x2H)>dFWi*y8g z@M#S~++T^IM5YM^M;#m!RZ=eRiZ!#V}Y(M+I1dR7zArI9&_b2El1VlB(Lx| z4^X4Mq>`ZJE_ax71hflEtte%ON)&<%-Yhv>9~^#*lhy1Yy-j$n_TFE~92E;skR45< zPBj~7deQ4E2eSv|*JxvNv&JauA(YtP5o7f+j8vEjhf&((bcgo=_J?$YN9HfEN-o%L zyeg(=ur{v+3L%4HJkFL*MR9#eF-%W!HU-hbMEHI4T5)22&`=YiIl zS@68PwH$nmSN_WSoytywGIC6X$5bvt@q;6AC`$H#IMKc-=BNnkeZW{SnOM+I2ND)k zX>1QJa(jr4ba1qZZQ#>Nij%7QRT~KBoSc}zfq_;hpV*WPo{2(0_hM_b8xwR~5jOKp zc8u-Qeo^`NRW))NEk&z{9VSrb0!J!-D7g_An2f8?l$`jWBq+xmXO;NZP08zO!BrK` z=Ii3$+(BDwv5+%3ug}$8pDI9$6{}$>b$JP;u~XW*~FpfeXGIZl!V z*#nF6ww)1GwDFM6sz}m3WoF)8_(oHv`*iwYRQl@@=mE>D)+ml>$PGvSeLXh?Y^*l!=270m%E~w{AI_h-vH&ZX; 
zhv=)XNyBI5TL=g5=cq|iV*M1Cl$$0U<{75C^^snF+e1jV-!2CWI3e} zM>Z@mnj9}d!WpLpALG;LSo17RO`Q(Jsc>d9Q%SKmX^uu2AsKZnhcl1w z`_zZ=s`uD(ngla)I1F&iGDn+|yjf))f=OO_tk@8=s8`%cR0A27gIyo3AkC{k_Ejna z&}#+oTtS>}-N|6l5U}1;b&l{q+0gEEukyu1!Q-p77)6LZ>7?K=_QtYN>{>iaSwv|R)!3HM{UMEh&CAIV9^5n;U zldfgv3*fB1+JPd)A<=;Ey5Uc?xs3*4pM5SVXSJYObsD1UzodHV)T$*GM=maw71%14 z#Tu3t(_!ZZ4yx4Ac$f4f(fO*TZ1cPft^(bb7R|!Sz>#VwtQ6hZ0tkvEZg9h3ef>vB zLVal63$UA)~p4H)q0`(T$@iA1Q z?4{V{`Ue8ty(CwpDNrSDT$T9NnS?_;CUGF_2sgyn#KO0z-m<06_e;e&4IElp*B)$e z9Zd7DR)>I<4t)?hK(g=)Yt0XfteD~E=@uJuq!59jgj^le#pDhSN$*3$o>t7hZMJ6p z#*%^)xTEQSkDD|E<8H8eo9>ay8+8ktOCe{mp@f<OeO| zD|c%yc%-}598CkOKWK$TV!nQ*$oN*QSXH96kfp~Je<;R6VIx&~q`8kazFkelA=woq zg55oRSDKYjVOG+4)7b54itx-aRJuB>-`_zHR)F`tf|^n2gr4}Rc} zZ(GQ7;g8%L7cV2_!ZcXxhdDNv(oOpf#s-Zs{A#EuG2Y}Je{P+a_wPfb9adv+3i#)P zGEka<)XzIuWSBP2p_<^vsYB+x;`%x^`)MPL4O^ckQ@b8JJo{2z>@RsYYT4j<1%ZbC*mWq!{IbiWsMK@jL0iJ!L4PC zuNpkz(tW23A90#XIECeyIB_FXlMy1>*h`TdV>V<|*@!Ep0#Z$(?!j#AO9^Ys<(3ps z-`WpW@^6y)mIwQPXfoZ4)>$RzOT2W;zNEzkNMm*`S9q)rzL9PXl}36veN}5;g`41C zjh$~X55#UM3@;m#=%jOP=ucnWBv?OY%KFPQ8V8vw*#Lexw9c>`X$o2%51a2f%Y(BZ z(WP5|Oy8(kq=nC|xEO0}B}gBBNwS1r`}rx%j74n5tg&g%9>QL61tP>$Nc=kQt z4@P9&Q}VX-4kn4^qXj(UF4~+i@Rlx42Pwwg#1ryp+TnDN50gIcwN88+5Z(ybW#7v;5)ZrN6M9&et$E6X z2#!=FCVvok?1Km0GiK}GPbUcltOHRgi?B)4tHa91pB*qYe0pDnFaDD6vP?JdPwx`lAdo41{R4eKkP zOX_KJ(#dAR@J00U%KW_~9pN>5$NBef$K3SxPmAOF4~pjm`9&u=dbKA6AG4IyYw<1Y zJG?vjj3*p?cO){zKPr5`tV`UKPktLScW#@c4*?gX-Mzin$EO50Z_)mhxwELa8SD09 zl3Am|Y@kQHXU# zXg$8iXx+!0kB*l*!hq0;QMmL|pR#YqVeiZK zCQ10%tqwCdlx5mN8QtHu2vb8e&L#&9Sq!1ySe*Z9j3o*bDnZH!V4k~2E?jnJ8WWdTsWy?%;dUDzYck1~402IDpF_kU)XcIzW`8VSmi+y++BDb7uiI*c_wOCZ z{n2>2)fi<>qmqHb(3BM3#%AJ)Yr!%!vvKtrnxAzN=u}%@>)arQOsYO3{$nnoedN8 z>rD%iWphk-qin6Hj8dK7sdm%N_Dw0W1J9|9lH}jnjT8U0H|6FC_8%1)M%rtcnHE?H z2j~@Mcp{h;51WC@+Y!u5tN+X1u7{YHH=hriSG2z)fc5kL)7|zH zhFVk)`&J`Wj`Jc~ezLa`mUZ*OG^f?eRwHzD^hk{SP-_tqR!x{ZcM3nE-GgL)9;xM? 
zxA^v|LjVl>@Pb>O&tKx&4)?(fosb`?R$WhoVf0-vDt+h4Zvn0 z=7;7b3b&sYXBe}e`Hx$ik)Tes85$+_St3M5BOb|kaYJ}p>@+SWamDTxa8oGHSr`+;*Bc|db19p88 z>F>$;YHhqXVTE)C$vL5s>7jra3b`mubU2}JnrVTBR%f>8`GNrn@0T4dr(=P!Lytolyj(L zXQHQFYp71%=WT>Y&m%q)U%@vhob+JGyPtqsR2@>QlN)t%{?AEOv)huOa&-?sF&J40 z#iNT$2lDEwkni2p+7JIU3zqM&t4-6_BR$jyk#o}*)*BHvS#$${lLZAs*$LBkLfA;X z*!>W14N!vW+eQU+AyrlcB{F)?3Wx~WL<7F^q2eQRdUwB4G5mw>g??eSZSTsax2=Ez zOe(<&QdUPod^(bsoSnSTQ78b_A8BNbMS=h!VvnK|{wo`}Zti2-NcU6ynMUJ3UB55U z`$o0^$f%qW2K+9;2HH1Ya6|wKXfB)I^2ioNcpFS*qru-vUz-PoM}JSpe*d+JJeLTM ziA)@Ag$F=<`cf{kZw*L*_FloD9O^O^h!0jE!DXrv)qJ*1>ZL7U`iOj@LeoOmNYyZU zaRhk_Y3W7O((VB+8+%6gV&YN(YIJM$LK3?di7L?8Pgo-3{Qb5u;YcP9K+#mJ@Ui>Q2g+F7YFAL2sRd3;giEe*s{p@Qsl{8xGt{wpQ8 zC0T$}@$WyyR1o~Ka^9GZoN=WTb`l|5OVT!%&}Fufp{u+5S5-xI=0Z*?vbmhGWvzy0 z_I%j@2ukayN|*q0v=-?ywOB9%MxTBU=b9|^FU*OSbh{12jmKdps~q6b`RT}t@A`B$ z_SlvKwzI%Nb4=9UT_Irp+k}n7&#KfY+T9RDJ_sbc28p&$UU{ks3N)7nOsFK4X;cLi zp?BZVRRzm@gl*d~Jv1dXo*=`dXi2zZ=XqdWkQ*g00JXn@1vT-%d|U0iJHM`Pkpe?b zsIN(E)?~e;QBdlZ==A!WK&IeWhE#;@(VF8b+|iYgIl~S?G%U7>XBc38Pu7){{hBIF ze}l+2B0<0wn_`jjkt=m0gscbTBQhfQuD02$eXHJXa4W09?e1t z)F7taX#6iBM=)p?fkt&+D=k0JMM^>&MJ^qsNast@leVaY+wyEVt8%lXCZqA!n&dk0 zHGPn1vTXme-Ll11Lm!(wtk!vMV$$z7HJ;<=qt1qG_l92(1Y+>*FWJUeM5I5vENvK; zm!TONyHMKE#Q0|W#H6z~Q|MZ^Tseo(6!s*1aLxL+yn6{cyP!$hgGYp|w?(n|uW%}D zvjbP_d{f!ON9ZiVf+qH8yy(WhA6~7j;cszYG);IB#!xr~bn`PbPen%D%D|`7bV)bQ zWFm{+6;T);sgpF%Rz%#DniJ3}EQrrEAm3MpdYs}Pxhj=P-`9dY&MYPZma4Ar>PI}z z?JHW=E+QUUPdqN%FIqM}kRLk{Jum&~J{feB$j4q9&#N%&)*T|0r$GtN>$r;6J%-4q zQFG6m&+|+Nd??S8p`N#S476=WGLg^!3%%7wd08CsysxcjJF|{_Sv~Q5XuW8=@Irar zMD%*>WiV{Nii~{SrSWQu+}jK@>vza1{9|F~I{5op>7kepC9BlKA++q6vch#gfoO*q!-{{AJ(* z`aAu7*#hKD0+bsO72%wvK>?bF0X4Gz3WHSI;-4vL5rW#I)2nWv`CwqKccAA(pe1FH z9ZryYVcWpZpc~CW75_%*)%Tl|Axn#vJWiwuKCi}R4Mq<8ASTQ( z>KIz)77PFL;E*}4o7oZO{2+YZ47d7Ij1YVxpptXIF-osE?g#v@GFBqFp#;ArH>)K~ zt>pMva!nY81cRk`A=NnP{=_`!#4UYJdPbXW-6a2^s$q0XiN!APxLHv5&We3k z$@NERX$xtpTdb~2@KC|ZVWugOrfKC(iKVJ(*|L$)y@vR;A5Up}O|DbPnCp-5ee!Co 
z1$G*$N>NShu3~9%M~MWo8vaK3ho%{CxGt|Qn5w;iWHA9jczsgxm@W=p$Uq4lm$dY_ zbXOyNtMB}nME~3-5A>evrAcx(LXP)9(?$r1{ zF^_nFlx#LvXI3*z+)Fi-N+tTs5d1&-zoGJgjG3eyQHvO|BR8fsSUjiX3ZF?aKlLBf zV`&8dQh#LM=Iv7Zau0Kb%isa#!!pmg z!UH4)I>UwT4TVZqpbbYr^$|R|GAEUYbz(lZP@MYR_r{_}O1OLH&ktW=h}s{7PCLb& z2+dE0`!d+XvMsZ6HMa6EiRA~w^V8EZRWv)(@h=^5)ca>2Im1yQws;pJpoPpRQRr3i|5lU4AG}YI0 zm7vC|@1fOU_f_Fo)qGSZ)me4b)Na)b5;f#lHR|)#EZ`a|#p+4U8ZN}zL5CW(gqq*x zwFoS=I4?CKBek~iwJJ)rQZ#kFvo%f+wF&Z}byEN8sIcl(PwI$j>nO46wI%BJnd`~t z>kPs5^YV51#1vry-?>N|Y$Y1(Airl`)_dqRxPcoyxEe6L8}uL>++P|z)f=?_)tiGG zhtSh!d>f{(8>hBeOj{b`x!A%Hn=&>VQ&O8WVjB|>n{$1eFtQ-?&6^8Dn~Ou6*$12Q zz|9qiE!3UO={D3=<}G0L3N^u|`RtJv!)*gx0KJ(Vk=a!|RwhoN6jb9otx(HFY29W?(l@89mJ=<02*BsNJN*ZUEh$p@R0zd?j1i~yFWu*=FwkY z!@4=R>8T|lK1l`d6L?12j@wXpfv|R=*I(p?r9Bd*$@F0e$giEBvn&E_->=dx*4AI_ zqrK|xwAi!=!rbl5Nd4Hry%*Ecz8ryY*MwR7!uTjrP0Y zQ(K+(h|&)5ybipp_x^kB@=xnE8660fEDbrW59IEq=k5!_9{?~8MQaSPjkbr44()CC zW$O&=st-ir_vF$JvYZZLr46T#4(D$V6ln})bGI3~4=cA0^X3jDmz9n@UJchH^`;Y# z9EA@0y!LyS)-|+_47rm+BaK9>kH$z23x$86>N@=N{kSA zOuR}C#nMcXoK4^OjK8^0Q4Q5xq_Gto3uZ!6VX*0i|Y;&PMUZroW%f zNLa=~dLWSTbbMPIr3#x*h z3|tDOTY6xdiR4*I1TOV+PrAJ=r1PNXzR(I0EJt%KYkr^42w#RBTOtXg`>DCyxw@LZx&Fm% zhhRJNJEHIH;oTVl6)?z({}y;TZWp6E7# zH`Lm~d!G;ow#H6E(V%XBr{5-%-lnkH-oV>p<=VnO-^MxJ$|as%#ormq-(q;*KK7XT zL)XDMu6h@~voE<|udy@eu=Bfp%gAz8{mZVX)tET%EbjV(wDexZ=8pN;&T8v6+5E1Y zV)=-y^v++cT`lST2kw27vpKc-J;U)K9l~*AWHpiT{WFiffV91>*1em${eR;FRv8mE zFQ|dBZTtXhBNwq-yP@deYoL<4>xeJo@c)6w-kU!z9XG+UbD9&Gq?|* z*F26ATvyr%k1X>JpBzmbczBXp?W0B>=TRShZ(ksy!A0_}9C)$`? 
z7C3*UTKV>+l2pC*d1>mIM&P+(WW_ip>_Hj!(fZ;;Xb`>5)y9=8{VMxJ`NvN3}7H z;&;$8n+?sC`{5C((i2wTHJjvhIPPj(t!PJIwNBhL`Cnld_^q2 zuIql)wX`RTI~9C%Gf2%ImT_aQbks9n7WuiU(Q~(%=B9GFNtt`%vG!;OmcIGOdAAj@ zS#CW)sDG6C^|oiU3+wT=iz28g@?Nd9-M#R>GxVmE;NiDw8M_72iq-s9#+`8K!$|AY zLB?2h`ctm7WbWChqd_rPUGAJ^b zIVB_W-}i5ez{Ms1i@=ReD@hCNAMgoDkLw#8_ZyBMo0#|=J2gEx+qGQujClmMvH9Q5 z-=1HxX@p@RWO<%<%>n7P*wQ;3&R<&ADf)M z0+te{C@zMe*@w3uOFu6^Tceq;L#bc5oUt?`UrQ&L ze-bqFbkqO&A?=ELXTI*|-qLVm?U7@?uwFp@%!b;KW~G>lRTJ27Gw=OFx=M4y#okQi z9v(yUnSOCXioeDJ(CzM}hcl*O0`8r2Kf=2*n{rQynQ8U|!qEWT0-k_nv zT|V5nt?ZR~QtcIU$#DlxN|L$G7BTjIBjHcQl9vf@!^+$Jcjk|KX#+WJdnt{x`}&`i zWkwE!l%<{y1eba7rSrEg@C^xXaq=gHKL4Wig~Niar=H!S!&(5( zyxZl$4hrPCa+}r)CX)D*N+G1aq!=}qa*UrXAtC0N;66?DYI_!u$u6FLD zyR)&*S+{eKZUuvT!~xu+o%1>F-*%UX1GthQ3z7JMlg$FeVY3sq(H7Q?R_@2U^T_9| zSKHIj$A>#?2hXd=y$k@loRk|%nd>!2CK5@XRAk)|5m68f#fkP6mduzBQAs8mH*5Pm zRhhR5Elw83j}@UN_X7t^$}B96K53h{F+aV^MRXmxPVzY(+U5DKewVgR>YL+0sk@c_ zcP7mghWa^o&KZ$Tkr6Jew@0PeZiTqT>(D?hc>}2@Us34`>WjID2 zXN6^W2;u;4GS}ShS<<`W)>}zZz&F%I70q03T+Za=P|ueA&kY z5;k?7_u%lj@}WFnEQw6e_(X7xQVNHo{LoiotT1Sp6a-T$5Zmw_QoHz(rga=bxKJR44yOH45tQnomj8Q9ek?%Dj^|)uN}NVvJcPI^136~%#dEXq2$ypy(&lG0 zlbT(Tp!3{j-oj`C5IE{ICWqB6yh!O6;d}G;bJHhjMJkS9(!`jZM1#~#diQ}}+2n#B z-!wClTJKh+QePCkLk(8Q=s=?RcmY-)W>>ETf1>d+DO3svYxasolZ335q)8QP{~DNr z!iS3t!;5uycoj*Flmp(R6@R|G!Hq-|YgjN!4A2#4I*5tOu##ByB}oP~U6ta!a4bs< zSm7~<$SaIH#?rDu^ z^>C@mk9qOgU3hKQr&543_QI-mTwVQH>5s%g&2_7|`Y}DoPkHjGZSBkIX5AG#=%Ajp z6-g%3fhfdD_;ztN-mz(Wa>H4tVBx%BI=ACs*=1Q#UuDG6{-6}KyBjr zA#oW%L7INDvFs&+KLq0fULm|~nP2I(L7B=`2TSR(PG{%}-kESC`|q+qSsdx7FJn3x zRd{{G8zqT~uez~}%5Uh$aMs9?`Py%PmWMkwwZwiu?v?7vj1nd_q{mP8l>^$wM3$^G z^I!MNM4ZOt?HjVcb{^1#*d{&KQ7^oz%$>Vs2?|gs7J&ttO4;^X{+sbgD?0-wDpR=++7l6o zd&k1zJMuc9Ul~(&`LAA>y+^!`jR#&7y>2%9RNyiRWfm;{W@)D3s6ADaAXuuEX8v)x zVXCaf-oVpl`zyumOrJ_k1r1_6!RQ}|TUDTZV5W0BH#1UGQ#*WM<$600!=tIK>s3Cq zO}SmzR;g{69zJwvyj?v0D{+rp=5Ln}OtrSmV69 z4xkXfTOo+7>-uYP?^a1i>w1JoP9j|I)@bo5>-%I?PGeH;)>&2S2Q)@b6B_R}cp~eE zj8)D+LvNS^DkWmK1)&SfI*zHpTt69Jy-4MB>PKzWDvlyU2#$1u7R{o%>zQ 
zKN^>81ZM- zeu(dglLpWDq)-_7k4{dlp~w#P0Bj(Ssi~=~tZZ;_FaraFhK5FGXXpO@enCM&Qc_ZN zb+wC&%hRV%1q1{Z7Z=CI#{T@a|K$rc4#pH=Xqs*UhFI*i*l>>8SuQ(45HF!0+cItB zILRaXa&7x9#4^ST?G#G2p);I(nFlVdurenlV9MxrmCLYI(N&Yj=+W(? zm-pYi7x={?DD=POJ(rN=)U@>f$a`rih5se*xq@rz8ycIMTUy)NJG@ouI(vHi`UeJw zhDRXXqZ7Y=|E*ljy&RukT86ExuB~sxEpG1Y?(H8O9vw?;pZp1YSJyXOmXBZa{FUhKY+-JD)I&sdA<$AwcGw3 zJlIXu{P|DbW19Y1IG!UH%RVgr$*$;E!Mhv<``>op$v>qlt`GLb(@;ddw|3p3-?xJtYBh8J{2TPryH2U}w*JsGo2Z#tu0*F z@Uy+;msl_Vw5|{}Wltnx6PULLCpDu^Z^{KknPn_j2eu!Fo9i`Z{Z@Fg-hXL@NEW1i3is8s3BDO^=P>VBf}Ua zp=*x07!?_|? zV=_iGubh9Sxxjaf72F=!wlkfdB4VE<>Yv#I!d#P_D47uU(VQ4{GdA(K&)}W>y&s6& z$MYR)w=kEP!K|n-e>VwSQZqMJ1nFTfW@9ZIG21Im{T;FgX~LV?gSG|am?L%%TtmzI zm?<(4xet4&Y($2`l7wwc#r&Z9{m0OQng!}COGNOS@2p)-Fh9guPcJ@fV9OmmY&>ef zK5E)z7*ITFUYrX(YMI&{JZc@j!9H&5!BRSIZzT^q?xjs>&A)^7Mvtfg|u(J`3-J!Ell^dM% zFKsLH6PEUqCV3i{DHN$`_N=+Tj;dBxZ>4hv$L|pFw8`QHIX~K;iJY z>!M-!{Bb?*<-$(C^5x>{Z20BU{Lb*@@&xTQ?iFkhQ{`&Km>%LZ5GtHruo58)q*#y9 zXd&N7Fjiq8-vl`>k!__1dXQ~rC*qOqi+YPYj73YrVI^WTG??k`NgL7T(p9aWdlmjWXRuT~Rj z2(LE_M(Lj3?AGc(y*=zbBDg!9&LFslZ*dbm++5-(Jpw+e0nn4^C>{_C2uaeJNIKuz zP{rP9;)l`3Gs?A(emeF^GNnZm?qkHkwLSLjJ*LHwz2YPU9ow@g(doY_>>%#(@dt*{ zaTu_6k|!|v1PBe$vDy@NQXBaMO5&`F`pLa129yWM>2^Jh?Rz85SswiEj-it%NA9ii z$CD7TWP0M-D>N?KlTdu59MVs#-Mk>)r^ZSQa`Xd*Jwn>v;Z|YaYS!gsgWJj@oED_1 zFN31QPfsEj0J*gHj4Qpb$S2VJC35M%MfJTEi$M){%Vkt0=$F@4Ly2!%XXG#Hm!-0; zhz-4CWW{G2c#l^Rmyd$jlI!pNRKXVsar;h1^^HwYbfO}`s)-3WSvaVdUY^)O$-v`_ z+4F@kHk7hQf|)qBXo#J!GWoq5vp@>#uqksnXo{0wsFIDU%hL8&;L=?_UMHIhMNDNX zp9G7TUg3y6xGe2Zf{IRJO}@i6J|neqINMT&`a%v+MGEE7$SR+=cJ z&7UvM3(wsIz`*|g@qJbHO(Uz!0ad|x_y9DA+MY)4HMk`@=C>R3i=rzvu-Z|1Ro*<6 zE1Oa|%dcd0XueQ5waU-ArnFN5XUTih6Cd{9Fsi53h0WsZYW4-c^X4E$D)Cewb9Nhw z)IQm(H-mSmD!`w{e^eKz-?D28vP_k0LrRR|DYbv?)>T*7|Dd#T=U823*XS6phIln{ ze7??~ZeB$stUL+>WNfP1cESY9fES!|gvFW_yfx(puAGMT%(MMpODhucsK|}@v8#t| z6A&N9a85{bXlF;)R7r_*nK|drO|DXvR73K}EWTk^&Q2tJuW9BQur1b!IjO0oz2&kI zrD9%K`C3xflSgVBi(S6Km-uDG9oUi0q383ywgEuV2z0p4TRd1TZd}SEaq7f|p2Q^T 
zZN7juOcYZrxl~UV?Ympot#T~8P1Y8kH*?pV7B72TBo>h0-CI{7al*WwJLRKD+Efsh zzCzUq;L`^1z|>{kdjh7Wnvcx7`P?0=dM``rTfs3<+CvDXPBUwD9DcHMW| z#NNKxlyr2RC_QYbz75&;5OyYA*V#A%^zy z>L(%OxI<7P2*>m42R-nl>ytzv#j{os8$QU7Wt*n8hHH_&9}K6_zw+0aGjs=Z$b(Mf z`NuZcEJTOiGn^%-yF3hUD=PBmzsDNJT z{l*l#z>uaQLhAd|(?SG8_q|U?2E}Mn{|<(;8pjtkjcI!po{bBw5TE*as>3gdy)%=D zTnN=e^}v@{)}X!W$e=n7P<30n7y(y^(*hzRCDvKEuLnLz0=~1jt>Qwjhq0+o((&(B z9c?ZY4JFYFRA0?vx7|!NK3tS}y6?U(yP27PfTJcgW12Aj2Chc5Eke{?1XfQ?u{vj_~73SPsE$n#Dm_+n{>vBLKB%r(}!8a1IX$Pyz*g9 z@&Vg=agO;g75dV$`tsKK2rgkfB@059kn?`K?JXsOL{;Y{Lh8rx&I|F7pNZ}(j%aRO zdA%_4Cn@s-u==ZD_{YNhs|x&eg8XIV0(AR4B{lt7w*5{n@t+U`+KL3)X$Cs521eNh zy3_?a$$7b51v+TreeLr$F$s{H3GkEi(FzLcH4aje3$W5enn4MQ6A5(I432XOPPz*8 z!3YN0_(k>wg%<`HvxaEMg#?lYbMpkJU<83?LR_v8R~&=Zd4d`>f(l7PKGp?oU;5@3 zhA_8kl9xi}o{&A{Gs0 z!VD#q_rt=BY3_^pW)f+(9o&`_ao`kFY7(206eA3X;4O;gF^yQSi+({C!!8xO8x)%q z6c=9@E7TS%wG+Fx9e$nUfkEarAX5~6gb}Hb99+#BD_R$Ghfu-ELGR_`4(Ou+_3@CT zc){8Db4^c!>o6|Sgxagvo9*~$r}%}w1lj&X&H6AWQU9N8iC^WD46mb%$ihsD;yvmU z-^#~WiAGY1#`%(kb89687bP)>Mu%%fr#VMQCr8I#$K^ODscDL3PT~1lbU2&w3NzD_B!dL;bC{{k&Ka2Gcqn6;csH2@ z7Sxs!#sXo0b5PmRU%1%5G>gdy&V8kj@m7oi4J%i4AosOGy0X!3u53t*d`O;RNY1Al zMh%x(Ev$SUtUSG4a#e8d`+>X+Gn2*KvUX#OwZVlhU;qfrE&@ox<&w4zPBB}!!C0DeUY!fuF-w`-g>25nZSR)sf}#5^Dd1~J1OfDZz~{CBGP)Z* z4uYcHmR{V%oVkP?ke3BE#Hqb1o4zSq0hg~~l_eJyeZ_iCid{~r2)#3d+R2xDDikR4 zR#3QBFk)BM?UV=mVph1YSM=Oe@P?M5Dk9&SBwaOB_{>#^>Li|OSBz^^F(dw~Zp)t! 
z#=l>wkO{3SAT7rtGP-cYB0y|yk?qv_pkUG-pn(O%_~ zkjBKW;zrlL#;M{a|HjIa!77MrP40YAK6d>O;(>gxY1gqnY_Pe@scCntnBM z2k8{HAh%bscCjIKHJW$HDs|Zxb}8P~n(cPgpmiIvcYkne@iOZsf92D)SlnIN({;Gr zt-;w-zSe2j#8PV8W+_}`hU1Vh*W(n3WV2w;6vj?OnfClHhPtHJnX*dQsTTrjC+6(% zRqC)NmvHOab;V5TYp(4Jb8{uU>WkWMb2jb?g|*CW#bu|I#mzRj@AVe>^iLM_YbmDO zaFll{#wCPdWG{IBxbDB6uPSsKm=`aeb8Rf*Y^o|5tk50wOsyNe9T?l~@1z{uNNGe& zCzS3Z5Ng>v_F%=Wd2g1^K%V??YAUE}e<(8*zq~l%Rcihx&fxn!tid6#{DnV53md7P zag+|a60NYIf^_n37QT_vvXRm0p|sTDTZ^Hwup!N%idDCc3#G9n!_gA2Q98}x1lN&X z#Iyos%n@aKD0xLHs2erB1@px)?(kUq{@5eVIDYAXfkkkzS<^{VQG3u>apu@U!RXA; zDDS~&_g%xn(1dExFLcW>RPe8S*hIhV0G0c23g@r4e2#qZGY-^>!z;<&RfhJWc(VcT0y z&MEy$4xg+anzmTPkH4S(vN#!|hm{yUlT$MDdU&SzU`F_0Qj%)Un``vaVkS6z_P%7c zet5RAykzpD(!AoqZ(FW~SiQMio4GfOGfvI3E-xlp(tewY&nR#$j#JHcQBAhF&GawM z4K~jYTTYGM&rEa8SC&q+E6)^oEqqU#ULIa*+FwY&pWk#}`t@RATY0HEZLv0d4iYxc zk6Y`y80o&)Wi{Mxqqp4czL;-0Uv#jHfJ=|V^G_sWGY*F3ai$`4mb}uIwMv&tUQF%{ zFSKwjZKomLyeYSo-48Dh?ZXCeVaH%D*z?b@WjENQrgK=WW#5FE|zfqv&yVwGmMC`k5;#D zU$n#6qwlnNMGdkF4~Ma~V2Lz!@@j8!C}f}UZ1E^;QOWEOCS4Y)g{ zx0-|ZJa1_#%?g(Bc_+6NR(kjhiZ`c!xc$XrZ=`caMP;wrW4ZZoPd<8&c&WSGW397g ztq1a~&uWc?dOzJ`x9xskwR~X4ityW9%K_u~fj>|056<3Ih`ohGueI)Bhzh%-Qo9yp z<9m9LM$2&C#G#=|pHX_c3FOG^VX*b)r~%{!czGPpa@?kH97WO}L+NUEFoF}>3=?k= z<2d2NTROQPytr*HE^eSZT-0Fz4)-;lSPY&7gtk6ET3W$9C8IfY*=uCxUg1|+scbkE zSZ?BcSlyFkhc}*)FP}Wc8_QZaV^Td~z(1MCUBwnZn;blS{j$1!Z{^$rKhoow#OV3z z-Z}U361}O)ggNzwxA_H8!=@V)<&!eaEi#`m5@x#_}}{WKsU*4OrzOU;q58YjfozwAuX{s&{2hb8Em^ zQJ*>H8@MS4`Ss?7)Az)eZ(s7TQWA#tFz^-W@}-Qh2(YKA6|1Ymz>Buuy}3m=sbkw0 z`Nc{d^3XbkaDj1y+fuj1?sI2vuLKP;VYFYeR_MNFuhM_Z$x)liS=ZLe`#xis&pfHG zYCO25vOQojQutY)J!HVBc#p%PWS`Tr^x(jN@z82C&zj7d$C6Ej?IfJne9SZiT-jb; zb}nFFeIe*jBb0RrKj#+7zjQK~_XfB$+=w;0oVGbP-rc$0-8DDzg^9X7dcSP*M0z5_ zV0=v5;fF=S$6+Ctm(r=Z!Yor#my6pIMlNg>)_KAp7e@U$W?!>_Q6ZXGqrzh6lxZj_ zHlKFDLV#CyI8{*X)Q&&LcpzHp^#i4X8p}kkLL#5#ZcP4%SHkUY-;%dI@h;J>wLRDs z>%%+I^KI;CYz-9|G5G#m?aw2RAQzjxZ%wh>)(;Xl z=fZ8sB-}DD?ygR^eoYUzNIu-%T%H~2U$l5W!4jeMMkRN`_CjZ`qxHoVyUHYbgDEXS 
z=l@*WX~moHYaLx6mCM!2PjVj-`d}uI)9QEn>^k~Tu7<0xitHUC4B>)vPHUWet91;K zFAz7Lo{J%gGDg28cb3d0kz=oCj8*1i%Z+`nAj%YvC9TDjpk-G7-C4u-nknf^NN`Hx zH;^+k$W&52A7t5Z&73M$PR5dEKj*w9B=U2$o+ZQmram~`6HAOW%irN?VOB7E18dGJ z+Bw$TXa(|}oLFrawtOdb5L-c-3;AY#Mu-@Dk%B;gHaoZotgu#8+;GEQLRCY~QCdBR zI8a!-+rR-m*_h)fYsaEkgm#j}euH$cE-(2^FP zHM6Z5w*Uk!-3M-D!}bEwAzq#-`SlDHQ{#k!l9p2vaeUC!V2D0d=b z8gf>neCo8Zp8{}twJmTy^8ilTSt~iVi#?`ROdCfYAl2jWs{m06x?lj5_bi#UlbRo> zuTeETso-03^AU5y?Y0LUL(eIJP6r$kv?2fu)?uHcE|xM!>^UT;&=D_`;7L6R5w{KY zG4s1uYny6qcc{OSVUFh*q1;y+z(OSnKu9f*#0g>UonULfK3?S~FJE?AZgPP7AZin2? z&!{u;k`4%;_uRE@08bXVCVJ!tunK^kBkPUnHim@fMFCSE6DBp`c1h;f+!jbG$MjNl z!+EERIwT~?z})r~3zpSW!SbTT63>}M_Z3#!;r7rw1u)J2k+4vzrC6#Kx>5skJ@q4l8&Wr2@xk*K7lm?P#R2XnS89>br@HP zN4MID<)mVgUgr>#khMrWaJi4UzUmReluM|l{r2K#J4P}|9}@MUM`G&d3+=H!SsArB z64vQUZCXh&F>WVI9T=+a-$Q1}vT_f>M@w#v0WJ~Ffy(2E$OWp@D%Tat#?o0B1MOdlX%_AeW{(dNf zf`SnzB5_=;qAhbZ^$WB7jzHR1AevSV06QR9-oxPon*O&moT}%_o-s!SyqPAd`cbVY zJ{4|DRokfQa5eQQDMVMS;EEW~7!~uJm%&{@{zE1)uY_G=Rkc6KIxeWGHAi@=aD@y5PGEu2A+j@=NjV8dU7EFW7twd4=L4 zVx>^1wWh{z9ydyOwTsE$g@OjsOZ{d(0`9M>%OiW@Dx3s0q|3=Chc`cFaKcOtB3VCz zh^NX@?J33goOMD@&nw*g+du9#EO1E+)+WbW;6xNl6x5IF#xrm^U4PY12zp(r?}V#I zq5(mb1Nc2lF@MQzFnvXMX+6VrXz1RE1nY4$@?9)=`ZG#;1non8!#xL8TCwgv@Vr5! 
z5aLZ`;Y+{S_5|NXubz^830J3EXd|!7S?f$T!wjDIO6llCOl;+0!m+)>;?U#C;MbLo z8QUjI-u@H?3Qkzz!^5!gds z8^ms;*pGAx;Uo!XNW8|^V&fe9+Z zH4@LgP7b%%P7*up=W`-UZb8lst=Q>KV~A2}yM%F}?zJ#4?QqJ5MUrTzRCR+9Kjdfg zrogTCtX)leP@eXxtyM{CHZOPMWH2!Yx zdacX@RE3iTS@m*km!g-J->TkIMdPSRu|j?XEL)0=h3So{#Cltch9glW*|MM7Ep_ub75E%=&vbJfzG1N>qIs34NoK$32gKw-#b$8nG>L!I`f>E zWDdrwUbSXnQV!LdKPM<_^Kb95B56BnY@(uhmL}WVDGAg6vH+i+c*`3$sk;;y>xIqw z{P=^UDxY0;;PS+U5}bj>+HJSO>r5?8|0lV=|5>irLsz1EePt`%BLOrBW=i$N$?9N zjy`!1*(P?3i*kHp3I#sOpcPYJqBr|P=dsTVQw@fl`m~jlinLoDO??_80ejs<{qkgR z?iI07-cvM%VX1anlvC>4bAwx^6}vIL6|!_#@WK;v+BcbGOB*!M3-ZSoboYWQ(6?Dc z%(S-e=t#}_)=i%wQ%1~4kuRsxjk@Uvqe|@ph}4>C+a**c$EJ>&lRT11_kRpd{*DO{ z1q5six8~9c(Mg^4w;`8htz{0c+C@yI;#oxF?MZ2(yuqE*yrWI®{r_M6pA__12k zr;RGL%BmgC2EqHyM#o{+d0XAnYe`2Uy|RXy`S>bp@qH?EHiLMM{u71)U@n7mcdmlT zpKhsxr6!y2YcGdb*PYCZg$f59N@bZ9`%KCx^CV%X+9u> zn3ybogFdHp4ynt^bCne&hpYkboQf5CLHf^v1FJ$ZInUuTyaR7Ur#eLQb6yelit%;v zFXc$0Je^jhZ(Vw6XNS0uyTjTPfdkM@F zQuwAq3{%=P6gbQsBui0)=e8xUsbl47;C;8`To*v2R>-E571~gYVUkc^(FEs7K$t!< zf7gO%yrkdISo$uPoGTra_dcC+r8HSqZ2nVCs+Mg|3T_IJ@bi1IB`4iA*BPo_RAz)( zr{4K4eaxiykT;M;X<(Ia;Ib}awrO;eXEMUH?oDYdX4L(j*?c$pJ3We#WL9ck{-g7z zG0;E?M5ZpespP|~xMY+xn{QIX`~yC2&i7Wmr9+08<)_dV5#^Q|V*v+e!B6!9dzsZx zM>%7TcU$%j1&)URJ9d8uqyZ<70_Sv=#P2Ms?8eU71+Ff}PRwszN4nfv3fzT^-IfYG z4hy{2Sv((Dye0xY(OA8S`Arf@S$&vUqi_m+gbIC`dwpdJ{jxE9RSNw-Z<~H(4X|Pj z{2CD8z#0@_^3|;{D4{TTmkTvGy)Xn)m|9jC+QRCX%NjPq8vfZRe2F!pMmFM*HF6{% z49*&bR`dpgEt;6eG=!8bhIuE5oh?>q=kr}*tV~gy`ja@7qWI4{gdf=wtk@F86cRny zl3aWf0@#uhI876h*+BWGTwz6^8n%RrqLd!fxQ?RKDbtvVqO{GTsMVr$_)f%WQ3e`& zI6D$JlbAh}5S+zq7EBM$7Ge+N1Lw%>`bmRx)y;g>zk>E` zga+&LIU30a8)`V3;4A};Jsiz*15HyLEiD7hn;fm#11)flHjjZ;G|u*~18smpJst8rRwcdp^1U7*eA?rn+<qQy0~fi16H~OFO9h2f=Od7qXAcLOjf9^vT%T_Kr2&L;F=`k;SVXu z`c&|m3S=YX_l6O$&kVQ~A-~}O+?F)mj<8B_1MZ|VK-{t+yB+(xHNco8$X?IkLO<{T zj=(W6=9e zXyPvSqqNE6gS9+T86dVFK*D252ZKIg;UQ+_K{hr)epzNr%Y%YlhN61>s9uJ8#*Jod z^FU+8gFX#K_q4eSD8p#v#!M``N$0`JfMP*;u4~G$eQCLIdTlO8cyNuOxXWc1n`L;4 
z-1wKrXAe9C+)x61-cuUl@~1@Hge*UffxJYw5TchShce~Q)_~77cn?1F68F+Vh^>F@ zIh2zW14#or>@aFM-rsC##3NG&B&J zN8S~*3R+}ZARUSAGBqFlRw+HV?V?Zx!yp%Qui&2N zxa!q64rvyFYM{Vt+2YqP?J8xe-mq}I)vzn4`79uVUMyp6hxmiaZn4V++LcBK$oGQf zGwh1Mh7b_G*a%Es&}TXuPe0;;lL)knJ=8no(nEbQ+|?6ar^ z)zHCe-1Zql)gQLlKFZpssR(}Rr7ilTZ=YgRtzN*U;aHvQA*dN!q!}rglu)hZ$fjLz zkw6I%)X^`}X}yT+sn#7e)f4B_TNaGntp2=gs!u}2r~haljaFlDWcuZd_X~|fBy-Kz zN7HYkyx(5J!zpBH4Drpr7xRAChld)~7}1y+2l5(w!h-{9Oend{OuzA(W(Wo4*OZuc%s+6LVPlwljBWO$H!df^JwPlc`R^)oN+Y*cDl|N3EkQV4TOgF z)(H%YgllYuQ%>7Nm_x<`5%YA{lU_(Ccdmug~%7qhoPmSSadi`J4&FMbChuanqr3 zs`a%W>*K#+Cm7e)+fvmNy4D|Ci6)7MCPj)?#?>b?V1o+k8!N7#H98+sLPS%MLsLgZ zp}*_X_OQ~o>YGnRNp9=+4n;FNDMB*w8;YMbWEEp&voy5wxsZvA?EuAc6NhpGaUi)G zE(N+``SVH`fPVp3k^cm){s@=)Ux6zC2^j%a5w!Im(TV`B2*rxPt_aThk8VYHR|INB z09S;1MMzf^dY+Y)mE7E10t#AOTwMHTBu^Q+(5TsP2%n=gCprG)Nt@Zc!udJ*@2wN>JEp#dt96Wqf z23|}Os{h5$0?=ZlQvW9?TjDVIKS5a)==}c;WrKo4{yUUS2BrL0C|mgdDMod5|A&(u z9UK3*lYKph*n0jql>L93Q3xn|dw2gYC|j1-_^|mFf}e1WX7iyYwI3sJ~LS@I&z*DZ4ja_IIA3=3u^# z?qB7oR`bWzC zX-Cl@sO#(!&$I5OZtgPu~vPpgs#U{w&pNbS4!pctlvr@2??gro4${_PX zW8cp7`==uXMqpW?-E9PxW#7pqi&Hn-$xAXqVA;PVsmuU&1eQ%O+btp~%m?oxwCwI~ zaaj-h-;&f`Y27B61fgZ&y9h0d#<7pkvgZ2;E&Hb>MJ%+pUpXejaZvSJ9ie4sjEWIj z)?tt2;Gdq00s1<1*1~vL|o@%Z%8c9$K_Z`&nJ|G4-joBB3Z7B-ybzBFD5mO zOE0E$9S<(1Um65*!DkGyiP{j81&F$oMS&iC{wD_Q<$~R`<>jJXXX)jV+a+JyvKKn= z3g&ZZ=|u5_G3shHM0n(CErP{{=xRL<(U;mtGKO4lraB&8Z)FA&0&li+6RmD`3jY?S zY7cMro|g3@HVY7qsoR6PEyyjRG4=fL_NWb=`|h~wPi3l)gPY{;R1GouuPQru_i1bYahngbhNy{c03oJ_jA}^V3H<$ zfTUwTR;4Vg_i`P?4IIb*JYiWlC$=?Y9ypdlsahy_HeQs7+LQ}c8-}0PYx)XHOe_BE zxgeWYw2P~Qe$qT&@P9;!pmG92nBM6UA`y6_0otCtacg^ym>MK}LkoEOP6|l@@B{#( z4OQuyenL!(LJg9_z^tQ3ahmBtdjr6f7N8?75_!T0Y6}4ze|yS2CPOU*Ku4GQMtS?( zix6gj6hq{LM93!3%-k00+vSb4B!_|yppAJ008suQcq07pBjzrK62K>lA{G!6+ntt>?@s;|^lt`lDMe=II?QyuI-i$;r8JAT96o$+V@%zbs z^;N0NCw1Ap{>K2Kd?NIkeM0~qz{?5mqduf$h2ja1ca>ljNRIUhfH?E5Tl$HGr-hdk zHwY?ycmk^TYeP}s>Cp(V%hi5nip-n|2-%IvD}T1b78nJ16Z0efwHN(UEO|9Lk+uTL z@Ga(UQNJRj3$W%eSe>DEU=reTQEY%)tl6vJZw0DOF}$B%HYuB`!SgD$w&&2D1yN2l 
z%sqcaKePus)P^2=A3;~vivQN8$l+x{1j;2}Nd{+n#Ng#o;fUIl;?vmyZFofzqBkWx zI6LA3ugoOCEj5ue_zDTaSpKNb$LD^OhA)xK<)8ODb=)DA8lx?DEGaG_4UicgnL zL6=S2%21c-!6kUZW%Ds!IMi)RartKMvgK+R>cM(BxQuY%i29Ta5;Xx#ZZQ)Lw@u2M z!X8F0Qimq2S(Z(J3g2Fe(xOs%oUl1Be3N_uKoS*NK?>3}eDQ*|PfT({K{BPO zKtWoxPcDxyL&5Dk5m`Yvau-w*FJgIAf0g`;&)Mc#6$&+ z5?cnf!K3m$mF1DmCTvqYF0)_B?P9**9tf~Cj8;Kvs_j8$Wai?^%`JRot)+Qi^c!Z@ zfw6_<;fHpSMOo0#g4K<^>t=$+&P)1)m4VH77H;2m=H5?Lx7Hi(d~|eLKu(OQ&ssS2 z`x!hcZ6(!7g?KyjO=pSs*>X%3w);_-zWnl&quSoATx)4p*L7_y$Bd^DwoJJ5Fe2;g zxCpnM=nCCUS)a@9xBL-5nNZtM(2NB}T_bX1Vx70&2otvXavtXj7r}nC)c^9Wvcd0l z+f~eV?m`^JV$_ltAmu(1ixr|aGz(3Wr!^Lv7(HajaZNr@JrbzN!@HzxrrfuK-iCLH zSVHz*$mwfb%mE)Ga$IEV*ZlBf<}1^#=(%2WhYJ(__D{X+>oi&6hBt25zd3VPY&uP` zep=Bk$$K@{X%dbmOt-3>kyBMrcV1#WzM4XPOHm}zmkGTOK7{Np?6tgc z<=yA$s*PSL$GKfkF@X;|N(NJV+V0GooYvJog!3$~teES+O;fGDY<(Gp3wpGAOV$i~ z>|Z8mZlJMI5NW;Js`0xXZGC{hY~_9T@?f%3pZ4g|LxI5SuoeZSPy@}$i{1|h((Cl7 zLEyEfK}ah4blXHE(+gX~i?r|sW}WwZEKcR67l8=iqv0#&Iz?g(5n!1Q=Zp_R$O1LJ z*{+0nG<}8YeE5P;*qwYOW_+F(;?$xm&|`Qpnjp?+y%D<-<}UM7n(2=&;THkC z)AU!b^HWRm6S48vCH0r9!)bDKYZ3O_JQij$@f-dYK>i)Y>zBXP6_S<{ise;+E@_~J zoDTxa>i6NaYdEaQ1{`Jv5M!YFB>8wJr1u()15;3e=eidZ!sQYT%}}gcfuy z6|9~a{FW!srY_iXO(m@_C_@u9QY0ig2sL#ZhZ%165Ey(=7V?2MR2mknVH5JkAT(11 zkew9jq#2U0>E~OBGlm{|CKdYLE3APutf?*_$0?{yBs6O~tR*R=mNjfr%Du`sY}+f$ zFDN_|BYb)$>_ea5+Dzzcg8)sZa3&|*85X|<(oh&i$N^&Jpe}SCwy5uvTqKS>?uxPfH_fOKn<$m2D7^l#YbS54ooJ$> zh;i2Ni#pt0N4N2g!055eXyvx(OV*f;x(L1^e*v;EmVVp?bq8o)%nU3>2@osH7QSv0 z#4C?lw;lO3IF=|F7uN)dV+ZlhBMvhuwq+(-PBfffE2h6NUZ_7_&oq9Kz(=hg_hXUw z*V!mp`8b;TFfFaP(7HIS(gzE>-}`#QW( z9uFvF`3^I6kT$hYE;XV)I!Y_0-!!#uHZ>A)3Cqj;iEP@3;N*hfv|Q6T*mcZ&aLB<< z`l=~jGw}}+)AVA4w8_GB{QmUWo%F$?xD&Ft?S8yXX2aHijCQTek^T%pvgAkk%qQB( z$lCFZMR=fV#Km*$o1!?p-9#F8kUnPASv}tFmc=8*|6%W~zS>;)w%b5yaY%xM06|KN zTLmpz1qu`j#kHk4MQXSecXxMpcX#*T?(WVGZP!|D*YkYed$Pyg2XBrt@&_at_ng=K zt#V)yM^T!GGo?jLK@pN{+&WN)r!|a@9s{Nj%CZ{Cdm4g=;a3WesYkemivEZ` zfj=!$a5d;QaWH{f5JWfOk+Aj{E#gNUa62d=M$O5D>c+_nj{7|!Yv{3B7_L%nAq*n8 
z>|b=YO;MBskN`fSD%=>FBR*t8!FE(P9#Y(RznW~z4X`IpvC&QWCK5_^A;=G#j zZ8XI*EkzS8)x|tDC@l4*8_tc;nCAjW@4h6PirAZ%*;$4qyRX{W2&bi%*_Ftr5edcb z`@~{H09-+dak}U}PXHB&OW3FBHK*w@E8b+s2`LI0Tk@VcX|{O^xGhCtPt}l7umzIc zA_CH5Y@J4Q?Y!&Lk!mt0MF0~RSyN@4Rj0mV^jXWqS*zSxZcvtN`vL)H_-m(C&h1nia9?nh691tA^5v+Zqj&@FQC$Jmi9 za$h{nAueY}7mN~hzCn73I4hF7AsH=^^bxy8lv#X7q6 zS_07w%#Hwh{W@{Gy26v%PcLilS=J=F)=G!f-><-bm{EUQzm^XXWDu?6A14qLZFr*J z@NC?Nr6TxMM8osk2J!I*iSau5ih7~XqKyjbjfx(P?z)KhRRe&hQH2Lzji(7wz59<;+d5J?&$~&KeB^OpkKjv+CdE?q`eS1hEW^~*je-c(G7;MhQG5P##VnJ+jF|u%l zX>vSsa(-iSqjORob80VgybU%<=A_0}Q_>ZA(>uJlP`LD~Ug-^$VqsR@qRKJ!s%f;Y zX&HtQOq>}o^9;854DRGKf!7S=Lm{+lhQw>^aATVL3*gJ>l=Ak}xm6%CSt&zSF^V;c z$DZtc94=2ht_M*7cJVn^|2dwnxqDS}{F8HnWb=F~a{?dcAA8L|<(n5$nU86=K|2MY zsLiV8d`pwE_?B_zA=~vWxqLy+x=m(tfhA-COAyHlU*o-Wo^)3E@Js&@_+sPB3XNNGi?-XVtfqQJ6p%IL z@~vX1Bhw4a=7`^l%vj6!S}n0&t!Hc=Oh@ta&;65 zwqBa>on+{xI^Hm#sqZ`ZO zTlDH%E3R8Rlba2fn@e6Bmsy)FUaiR4t=pU1s4Sc4-C#g=IQFZyB9(1RrEO7+RlF-8 z!K(yZ@11(<_07p0a@DQ1$t{%ZjS2YcgGex{~k4H_ao2l zqV*mydzE)tsT?<<67pPDbzj4wXIdtNlAHOPf_Yd&~z?)!T9F`{T^}>g0P! 
z-3NDL_pw+G&$13#4fo#|9#YyJs&*efdUYhmzwyEW_lS)0sMaItBx~>Mt3$%P+tjQYdHKuI$B- z>!Z7{Cg$yx98x}is@+G*>o`z-4#_-+diIUJvTb9&xP_E02+Lko$bKM}^@R6wpEzs$ zG|Maga;3VT3tMQ5{A#}X(hKimnjEMaIVb=G;{sd@9yaRg;%w{Hiw5 z)l*mGY65mQ&&aAaK>!)r;{yGD&0}+4CnfKcy1;(E+UD6y39S@r2RZm%=uX+G5T6R zr^-N)70K5RCc$C)WdOMUZFc$SJo|-*)S*mQdm6GX~jj__NG%gcZ z$IY#Z!M>Ye=b?MnI@JPHlNOdIdn?0Rw`{O3@2s_Mf4wZ`x_2&hJ8N|E8>7G(zLHN- z?JZcvvuf7cq!Nz-?Ihhf=Zq0uya&LF?ONb7Qr`wJt}D<3ode78S#{sLoo9zw`#I=9 zu?kQ#>F06D8eH5AP$`y$2dExJN`EK!(E6cYGj*gTxb3Fo`QvA|*$Vsw?quVe@V97c zfLdr9N&P$STZQ+z)0gsq+nL8SK6WBMLh%=Vuu*~6&BOLMACqluU8G zD_#QSfZpd97$z;{=-wRz9qPaZYN|@s&&(6W|=rc=5Viat^upM-ZthGBL}`n{7-G zz4uYJhHBD2yYzmivU1lA1_j=KUb$MoPJI>*J@;#aI6*4ZT@U5zg zCZSEosX^|AvfQTfwQ2&@wul{PWRB74%xrd&9T}2p(!CBOb4!!5?jptA1hXL$gHADi z*k{98KSAmMjT93AlANJ(!tla^@Rw1&;)hvC8Zm-S!#22+U#A^8s!V5_#UoD^BVSN6 zb9Anvn+PUXPn@o;1zw(3MqiPgZ+p8&t*@nCv7fDAice*1v~#E$?JPt`Un;NP23(yV zOhsQ^EZz45xJz~bj*622I1;wzbTPhwZ`A-!pgoFd3<_09NZ%mtH|W4LP0W+9r=M8a z$x$dlwD%r^KL8yc(@Cq-ikdzK6|sVuB7(43N*12WAONN1m}*gFMvn=ZIS6o%3_e0^LT6}n6y#b34Q{~0eKD~n0I=#zmZz7zN> zn@hYHB!XoDbW#P0Z)Etf=x>%ELq{c?`S9aFfy+k?zfkh&=J=>)s83(Q@ttiNrLnaBb%?8(0PDW*`F9~w&2-elm73J#5hzN6C->kodj>$Fi z5$5>@u4~XK%D1Z!<`aevu4{6Q$#+{57Q_;4XbUSU^hXgE5~jY{(3Kxk7_K5L$_?Dm z*HKg)nWbxBxKOR#C=s;D$a1}VjFeY0s2I;OP52Px|f-29ZPsJx~^ z04awLZhk2pQ{J+MRLl@;nYJpb?BPd2Dp%fYnZw6aj;bJ4JAqr4D~kB4XOocXlff

%j%!2I z0HP$^aS>K}ix*AQNGG%7CO?k<7E(>rbT4SfLr3Wy$rMpD*U*j^81Xci{EDbWkZ{+> zRY`-29NLN}EVJtuI<7&(4{ehS+6_omdd~n>g|^EN?FN;Ozh|+5cBm8Xg#cTXG#^Am zJ9T9C!r&+0b6p320o4{|ee;N4Ot@uB`opoVXX{DzhI4 zozQ;4Pu%Alw4Xqy4AOb2O57hhw4cN^q4Uaycpw%?c#tBjtSc2wJeVqTkS0F?(v__y z9?A_m$k0*NQ=B3mF2x->$TFMIQ@tV{sUCX=E}iU2oBX6na5d@Tz2K7$Y#o)4;ZvkbT*D__W|P2= z(O0C)f{@c5R~3XS0$afrmObqYo&1!{4_lQC#yK5GRr#E*3R{yOJ{>B>nf#n>16x;z zoQ<@qe96a&hHdD`o{hpMzZ6%)HnEL^&&F3&zLrnHw#l=JNPp4|uu1dCt6*_#r$Tek#xZ$uLivzh>7FIRyk49AM zWG_~+sSA(z{(s+M!)*yJazvX{G| zQ3fX#9KJh?W;Zd-#vG1Hy3lRe1QrcDO;# zgmm+kj;`)YQPJw^Y5@TO3JQwo=;*DjEg%qRV`HPHrZzP-B_Sb^ot@3f%4%d}`CrN5d;Rr)C5Qj-BnPhVlEc%RYS$sy-z0}$ z9S35%*nb>{pNa$YPsibJ#ewK|$KfBvf$&$y;UC2T@|)xEkK#c1i{tQ%;z0Pf!Gmd|$HQp?(D%HiuK7g7jG?a5NY4Ki*}e6Tpealp+iTg}j8 zK7*uIPnxYi9{6xJU9whXFjIF)}`eYTGrqgrwLqQvmg`EIfy-`TmW`qkR8|J2?r zjlk)kp~dtAl`BG10pvIqTil0gpi9(8=;A=TaVd%$93hXeWN_{0Nz^U{LIf~PwjDA* zs-cN-qF@AxJKP+fnwOS)42|1%V%k^R7YY$1#CYZ0zBhxbGxCJw>+H4@vr;6`|Co#D zd5;U^&Mc~xf7Y?CQ6JvZ`yOJ?Ze{+!68;3%ct|`1oC`B?Dx?F zB2+0Kq`u=976*H5y}LsRlYjcstIwZgUW2gG<%PuLs{rq(cW=BwZ-4!^FA$KCNcf)X zrIL)1f0bn{t`52=yHlP|#Q^C&F=#uKY&WRM^6q{A*5dbmCc!{1!^+Kqh^wb`zG~Ax?zX$UC`aY6&Q#kK1dPb)m_8^A_v_cOci6R zkxE4TuTfx2GsR5nvrQX(@P*Bb6=Fpr*38|XSGS+R#%l|1ag!-Z8q9oYl&;iyIuzn` zkY(D5$icP^SDhYY+hS*^J-qJ0UW*Oiy#_hAB%Xy=?iYRM!`o z^x!~w&6!|$_bs+`)pyti9i&=ZjkVc@-Xs+I$`(yXxMfJ6DZDO7NsC*ck zmJn69cL8U(azykgOIw{@%g#`7#XN7erSS8q?L*m0&4@*7jziUHjJ@LiPL_|TJbB18 zBejR`S(_<5R58|M>&WxB?fuKuajL)CMU~Onu5rFP(><<#C&OdF_9CMhEF02jo1sl| zuS1oHrLPgfZs=Sd{%*drF!Pfm1HXZiFqFK%r4RpuvO1ruvo4`*u#Wv3$_eru7~m#4~L2tX+I#SRBBK(Y0=vg>IQp6qWDu?=o8#*U6;r4=P+I_OJTjP>Ij71nNS7|D*of7D>{|IEQ& z*I>1dI=A>Wz z9+lju(;6$TESmq|&(t;}GZy#)CnF5eI?=+h^(B#ulGe<3ep0#j<|V0W5@A8w=gE(8 z4p-wUg3TvAR2$urEm%#GUVrh0z2*iKqKDgLx3$`xexg}X`9il)TDby>S|Jyq|@8pERWH)qb-QEZFZg* z*{nb;pT4(w7i_Zz`Pva|xqW24`BnGf5E5$);DQDP8*l>{V{^`dG7dq)-nOGiv06pF z0TDvUkhMpb`;t^_yD|Lfv$N)qqx~}qyCxbG=NEQQZLKMrQSOJ>t1hZyjMxH0K6pcr zYF?n=OlzFgS#vbmi!D1mOL`ABMS+CaJvD*=W;tvfAQsPakQo5BCv^^z%l6@#U-<@J 
zC=t-21fl}ew7;?1+9GixMWR|B9XUOV0RUJX)rlN%X+K$we6SISJfMRkC z(9yOTI069V9A1o=>X$fJ3ORiZ)(&=Y)e=&2J~q#IVUwkW1T->z9)jYC2EfTNO(sGj z>;ae_JGslTyyblVc-h&m#QrSN%wEgJk;W_l;*L~ekxFC*wsj`eLSg`Tn8~_+D%Qzw z(K<_%J`9$ACSfZ}Yt`#4JsjdRviwat*}na&`wXpDr}I~VW)D4xO=pO!ruL^rB9`vu zr(bS;)7dIv-VxH9mh+l#u|1%5Uv7D)bmSe@@>hoUCXVt`8YDsf{ z*=G^iS3=0W5X}WG7zw$?mkjPpapHR$-H$z3uk872>J zC?jF0uYHL(Q%UGxUuevULzJs~99I}>s*g>x&pM}Xsw-oHj%&D!59VWNc-Q^#HM#JN zRId_rrm|2k9HDUZlkhfhL{@8fj!s0qeZ-B_i1gNo=GllMcv!1^M4t}h06e_DG<+m9 zvLP>Wj1N7k@pe>tX~aC7VKLPETrTWgNK}Mb6e?HLHk5JgV(8@@92eP9$kd6KEX0#oH2?p^4>A2^q7&TyP_fRDbz2V;SA}hpizWsgmA5N&2*x zSoZpZF2;wsLxT^xAM|zI=Gl`Cx!)LIB!3o39>RA@vK&peK25eIPO;}saZpHcGEV`! zgr&G4&d`mfc%7yos$YKGsdxbjsX^wcAz`UuWvLOPsZpn?F<|1fIPSCrg|sB|w3M*4 zw6e4es?oHp)3hAo^t?3uA3URZLATXLIu zMtfRv$7x0vab_zjnwjQmA>|h5+K>j@!JO@AF!LzGbBm14B9mokNRTnd zzR^D-yv;yze(FkBp34?a$gxH|&S3@!$zv68yBy6|-O0}oL?Qq1W@#sC-xQkvq&rPJIj3rLOy z(ZEp_tT>zW$hFiRQvQ+%)h8deln3g`saGuVD<>ksG$|OXu#~K@qckJu@E90_46Q-P z){Lnhl~h`&mwEJ-<1&b`&-w>+3u!)q4vYOq9^(aviI!L?u>CVZvZUXxm)b4Fr{irS9Z8nSUF z3evi3VbtnRny z`ysxkEt}))6NNHb9O_zL>bGblt|Cm`tI>VW?)+!Pg{agTS7-G z(|W5Ka$C}PTh=;q3RZpE_}z^2wj##1T%P*;jJt)F?TF&nbM}gM*}nFw^|l()jylo1 z4N4tJmK|IX9k25`IxO45pp8A}clvlbOO!gLEjva1I!CeE$1NKtBkoLRbat&bEa=}^ z;^|xw)bBK2r{3u3*uv`Gsi59d>NpthJc^(`S#LiR?YiJ;eS+7ON!p`O(Z#ddUF^}L zoYBKo*Hg{YWrx+ve!Dk&y+=q>zxRQCZ6=k@`l{qd^H{P%@E{8nP$~2KqI$k$1 zK4dkFOffF2JV84q);@v=kHIDvc_tq!PcB;B8K|5z?3>)2px%Z}9eYgP%ADM1y0ga= zIi;O9buQL3u`yAPJ$Z#aa0u%6B;IEoU|( zW{Ab_;9kz$>73bLpSewThthhsUu^bNf0lleitcju4S4R({SVW~IYXRLbDCl^Y?rr- zCgyrQh6O&{wa!d_Vm&VuHU9!IS|GAX?LINDIol**P0hbS>HTQIq-J2N){4MrK@JD? 
zO%!F+*n%n_szMdnIwn8^Cp~*&@x$e!5!sRn-_j?(QJ-34Pz(|%XDRXs0NMsv<1E{d zEjwH;;cNqNjFw%!mfa?ovzGuVkQE=k6+h-LgfW2NsFjeam9P&;amOn_z-k=XDiDA~ zpthP~y_y!a3hV(8^jKtXuIBKq-AcQO^0nylAb zqSo80)}KF{HWmEXL$=Y!w=tlyF=V|l61CBAiIR~S-*LGyO}07Pwa(l{2)Eu`j@n$U z+FZvmLFm_N-;(C|w)R!F4#^hHdN)t2w$3NF;8_6dm}*G}|HtMCG`o!!z3$uVD@uS& zzpzayfP@d+!DHFMx7h)2ITBazkY?||ws$DVcW+OXFu-{ERp~ z3zMVb-DG6x0I=G=guonu*A{V~sY3Qm6}G8qwQNet?M-QkLyW9GX{MRU&i#*_^MW`> zaOH#dQT z$f=%PCXHPoypQOM*q4%jM|H)#}Ujsmo0pvWd$}GLwtAU2n>L)hj~#eUa#W z6Ynbl03O0}m;;4MGlz;T;SmRlr@mSHZaFB{4ij8k4!g2FZ~nXG0K6w@3<7qpgYIK_ z`sVfR{>ySWXuIe4S^xf@mcv)$!omZB-z^6PVN=e$5C!93ECG^1ec3<$ZaLh1+tK2!Ari?Tqm- zL!9o4d}MZruk$)qEL!%7Dw{mcx=KOeV-csr2K@P$5G*6XNL0$>G5n#7K?bOum-DlF0aNqx%K8sVG?7SmRoA zNLnTGQ*ZySIh;^xe)bUzX15~#tot!e&rjxp_=`R~-!OM_DeTLKiE^_eW{oqf4^#E7 z5Dsh7DbcXneu}S`q~`lGH4t{SU~8DgmyN(w_udMarRmOG@t0xkNK5mBH8{z$J!I=^ z&!PKjlFY{X;tUqIb46YQxH0ctgN)4sq(CB2+U7@rSpw_O%J0?K0&!w`9%Hd&7}aBO z4M;pe|E@WRG4g;K@!u%^2lt#@bx zmHb-2iw?Aj5q@p-nHAu6#=>#GcIL`+a0lym(V>$Euk+#EK_&lA_6bY8F3yDr|1R$D zro+Lwe>dO7IbM&zO;|wBV{Be3{NAVERR@?AexK0Y$bdc(rb_&NQML($>cD$}KOpf8 zMj1HpT8x)qQ2LE>;Gm51chw;>a7e+Rl3-Zr^F-jV3c1wl zv?3hUPK>;E9blD&WBP>?L1P~(F9^qt8ezfXCf{X;LFM3yFQgMzkV(^p$lyuyjY`PW zwe2uveQ^Pqw!KLfLNaZS%||rjK%f$W&>gIaW?k+^h0MA!RT0g35VK8&%z5!%63zQO zBMY7P6XSy}1iVoRT?kT!T0<8@-baNlh8a{rmm)q-hAu@}UHcA>L}X#hah`m{D+z%r zVJk_I*4M&ARM=`7B(sWmEu(NUY%S~Bc*toa3t!LcT@;-#{o2B%37*QQ@0q z8&xD*6$g_9*UrNw$#%_6@`&v^Y<|+61_IRx$W9Z?hIF^(ZZzU>8Pm1%z%~`J*TqY4 zMY`Yfj68C`4fd;wHGpgkyH8ih_y5K{$s z?LOSX=Nb{jwFWv+;?JN7%RR=60y@%3&-6bRs-4g}~Q3F4;%f%}I90sPB?_)CHS{%Jwre%B!Wu^@ndY7n^i@xN&h zIKO)kxPN#MxW9-H*B%7!A0hU=_2&9Xky$I}!0KN$Fi(tM8;ETY%2-=I_ya*zU;KB&_ zi$K5#Sd0uLM8X6jgP};M7)a>Y$T)b&_%LKbiW^X>8zj^>V04HDJ>dV%*Db#qpZ{gY z|I3d5cOd`&|A9RC$3Q;V;PB@_zIX$CIA+w$^m8F^x^%se2VGwj{XUU*fP6TfK@eob zi6;+^&Eq*AM2hq0L>~0ZM*hb{9{lS@{>MZf_{&EA`$Qi2&qhA`e7A#R3xv2U`hw+R ze?Zpo`mAX7#UXI~>(<5fUD3;*BYE6EM)JVnKSuJP7{qDO-$wF4#7h2pBu^t$i=nV> zPt*a5L7bLCtmL=9UEdacAyfwjhW(hy%eEk9^4BYQP%M_#^>xv0C(e#ooUbkQWVqMo 
zMX}<5&O!~hug{C(ipGJ1TN9@?Vu#JIzV3!`N_o<<(=q)S50+_2kBN7E z5I06|LvUMLo^$W|`XG*s(g?STDB$h-g>)oPb0Dsan!lDtI@X9r7N8fA+|&V}0&VH` z9@@$jnFM|{Szu5Xrjal63o^+|WHd&nQJ(Y*HrQQYzGqiTt$yhjqJ4jn)mNTc(P=uU zU1KpM0=-jP1W&a7JdrIMepid{W6%@|4M*jt4ih)n$B2H!eW~tG9iOSCA~)|yJ(`1y zn)W>q-K$6r*`-N+r9YY@==@}-H?*!Wn|P}yMOTZ2yv_a_!L ziwlPbH^06qpD6isREq4lY5IypWjpAIX^D-=*kDcR5NZkfnAkctCe z4b`K6xvvqZPe_etR{Na!zN5G3`y4rVgK&~zh^ms}9JBjE``+C>kH+zL#M$r;MZx?4 z63m5L-J!Wgd=I>=Fc;bLL)$-)eT>Szy~Oy=tX<)GUNoxMGA~1Mx9#P}I4d5V=URz9 zE_?-XRGFHBiedC2lb;grX6lFpzd7xfMK4U{*VmTO8XZje@F^_O=)GJvcUxxD=ZyE6 zjtXlE1Ig5dnP0J2@87T(sk=PL+z?%tCb59Ssy=EYiop?wM-dO&3d}!t>gz#B#>T7* z^KMiUXd6!59d9Q4npbY{4_;;$`+jcpm>IjQwR43MAyiU!9Z4QmZfbuVo1WDyu5o<4 zEgt-64A=0u&g!E@S!MQnpSw7Fp zg*fCSKEYT*I>y@U+7Yh5=kus+5f#_8%O-Cxa*q^~<7mG}@AZDP#nFn;5m#5RxPMZ9 zMFS&;d80H!hfJP>tjd;fb*6gCXYDntzPmE-&c8Ite=xDOd$XLNq3V#EYJ7u3&|+di z#W8`Obc?0eVgPP^K>AX7#n#Yux}N!n+z;W^-4tlh)x2r$I0CRmjy|$ZO}SmjSQz7C%m&R1a4Pm5@93tJ)k!5%9lk znKIeJB|pC-dnGjD{W&K3YF8@gN@Us4wC)k#1V*w=?^laM8jC^5*PtWQ7#dsZTAM6c z+bL(O6Q*X{+Tkw@$p(a`cHJ-RNFZ8hga`4^A!nX6NgzJgYd- zw+ku2EsAfI#sEZc3OxqE-a_oC1pojY&}$HTP-30~Q7$3}HaUkBU#S-WhsQ#;qAt(} z2XE1V4uBj7sS;qqLnW1CM>SJbiIi`ztx6ovdK}e5oODZ^Tt7RC(6ZCJ*f9{w4eIK2x2yhmd*CNjP&>XcrAw&Dju(bPVA*@3bLAlKroG3T<=?7DNrG8EtpAi(JX=Q`BS5K~TQaRP6DKQ{ts{WdEnC~81p1M<*+c8uyX=W{!%OZMp#8)PNzMop%?Lisx<=%H z_uhy77|o2T*^@}i2L%tQg9r)v#4FO&!xxS93Jrh+_nFpo+X?YPiuaLg@gbzs&eQf$ zk*OwNa%%MZ!l0l$^~KHM81EftBwl>GPmB+ClHkSXE114C{mXootNJgAL4BbzB`AN`e&R{gs3Rk4yj^EiUgq zLqlGc`nl`)pY;U`9tYc+`AX3Rw#tP#Lpi?48<>^`o0|or+z16wg!-fVebWh5mS<^q zadbiF2$%m7i5|ww<$r}7iY5@2=<1hT8rZMRk`8s~S`Kpy^Ehj5+*n9bRiG{?82+R&~tTet+8KS_8V7LSS+|3HN01~9>T zE};T7@u)c#>K2h1;%X>?o>cB^E2(R%NT0|rlITX8Am$eTAUPfslDP2IG}Vkb zrhv*~_wG9tLpi@;q&gBOM~=qszD_l=e~=bFn#OCMlE%rh08NW?OQlXqwP;IKI*p^F zcl%bKR=S$@Vl_Sr$})%sC{Rd$%Wai-nr6tGzpH8{fawzrz~VCE(_%} zeTF`HW68#QG+C#OnNXK=!!7%yExUm}d_67-i6Q4`G-V4T7e&#g*W6*x%@IvE^Tt`m zX<6wUU>o!T2Z-7c-yw%of(JF=+2k z*`CFxn3Z&zPgt zA^#0YysTcK{93jmNs9Z8BA3+y_3-?6ip*?bMd_B-Nt9s|=p|9=AsNCYiIgQS 
zkBFT}9D>S9qGL;ZFo_3FOZ?lRMO-EH!4}W~;)w8aXzNL7Bq(x*t}IeA3Oxo9ScR^( zMWtIn7v!_E6``vuWw-NRLNDpciy5Fi<>i6KF;#j{OoiN9OlXZ_h5x;ZmPgQLn(_*t z*orPp=tGi9PvfH3FDt!blS~wMt#XS^S=S@7 zEvwqHAliIX{S8#Drc>=8S$&F0WJ5hxZI6_ehFN2}Q;l*C$>**?Rf2q6DaEXSG>2v2 zRzM;j)mqPFN)nDkGG9J=s`qf;jV-F|CfQoX;Q<=}v+n*QQwT{0BM&OflC@uvjV`@T zb-PYatsb2#nl?kc@6?*XBh$gQ{tkMc!Z%5(j%@z*EPILulGOUEM-iIXpNp4643dK0 zGnXgdB}jh&DSF-^`Rb)^D55dUqdIn{aZp-Om9Y_(i}s#A8|<#+*O##=Z6OUb%_O6Z zIE|Ti&Gi7N4G`K?jdFdd?V;@VQfI+8;HbtAgXFt=tV(GSEent>VK)}oC# zUt0~%v&|RTEZcx2O^sNjvTHjuFPz(D5`Q%g#EF zPFd2`5%_JCyBZ#_pXTRb*)ky|Udv_Au z9`g3i-Why1ks%5z2rr`X8)KtV>P}|t>u~GU#OT*G7?{iG;^giAGBF&=*p-kruop4x zGBN0ZLGAf6SXFA+V`I2#e8_NO=tgC~DR!?vHsWqn_E81g6*d}FITUsQkC}kmSv45k zAN~YWtk56L<7uzG+jXqe>;h{G_ZW72)9HcTWiQsAr#x1G-65jiIZZki-Z`qbI}%hr z$VoTUQ4kr+MUd{5k!b0e<|CRMao$>K)l7l?I^S|+mUm2^wB0(gM{#!4G8;A#RW|Y2 zqL3wWY&)`#hpl<^V)9hEKl!3)U_9|74^1GZvrcErmGWe|lQvXrXdm`k&MJ=iZ2FFJ zV-4ZVATK;Ua12P+)SN^U>dq@#fl67$K!7#t2}7gGVu;n7U7bLs^I}LS&w}sGeifK2 zvM<9}=V4bteZk|8&|)>5+kc@?BIeBNBWMN|dSc{cwUsq!_>!c!`x z1*y#jIcA7t)q;`Wf(j1AfOJtGREQMmBX-$qlK`D`BO?Vao(fNwX&Rkp<0lCtul5t&16~VizRq2^|@+=jCOtLjeq_qNOf_&>WarL#b zOZ>#~wQRnC8a{$`3Z8YT+@2;f0_w+(9c?yO-7SJAhw53DT9=+Wj4-2^Wa_>&> z^5bZ*pz1#5y%kC!8}N(jeKOw_aYJw<*#Wh2#<}>x9l=#Oa&U6h0cGqexe6{=nf#DU za$Ca&L~L`2uYRce3Zyf22&z3aA_rAd@*P2aJG4F=fh5<>y+JLLM})cSwuYcCzGM8} zeWz^TG)2@gX77dvImrF$7;Wa*Zwu(FdU8|nB%~ThUUhLo(i}>c+(=D$|-LGHCZBDF!kt?12O4YUdVvZ7afbT9){`}#;p>+co4 z6>t`4&->4cKJRQ>{ht+m>^;;i@IQQ6A6%=5;!g4xU(O~}2xUdA=zs9#KJ=z1#OjC@ z{dc}xABbY)r==$m`<*ZE`-(DL`LKWf!IzyWI8g$0Wiqe%a+>!yCoEA%+M5E0dOue5 zdhXu>QJ&z)74}UVs8m{9^W|m|shiH=mM*|P8!!kdBYR2qPK+R_hUVHf)&(G z@!_L(Z@h4DQ3>45p+yE9`sF)cUV2FSTB3$Ty9I>P@qI==ldG$E&j%BvF%AGPlOg!> zrVhSd*jK+NcO$k5l)eHyOu z@-uzp?OVu83g7Ya`ogot_44m{ITsHX&;ocy!StOkUrh`AkMJ_>{|GN5=<=Qa8(zMq z%T)hIy!?YMlmAz|{DUr&{};Uc7hQ(^C%pVOT_*iEUjBzJll+O7f2GU6X{5xHS{Ee4?r^^IC@$&C zy!pnad$i(w)EM2cAdAjNM8-1h?}BEaI`n^43em5N`iQlo^7 z>BZFC!mmgW#lfdlIK7f{t+V_}jn+#;CrUdlOMk7(h_U;V!IWv%mgUw}P0at>kWU~j 
zJtH$KJ0~|Uzo4+FxTLfUah|rSx~8_S{y%51|Eyp?TwPn=*xcIQ+1=YeI6V3vWw3Rb z$6Ipao4qOihZ$@-k%E-(6Vd-`2K%md|1TZv|CYg~{8tD2&kQ!@O8xIhW*D5 z_Foxn(*N4QMr5#I|ECW2GU@-RgG~~o{a-uSB;P~W|G9$={Sm_c4;^giUm@&&b+93( zi>W_D*uQnKA-{&Of7yr&BG!5<;)A0tuxKSS8pBT)i|&5}Pk zUlRk~|1^Y+_uESJ$4C^%0l_(;F4Zbd6{~p2yeJ^4E8o~zuEMfn` z`9R-G*u@+Fa6a(S#1GK-59b5^1bu&VKH%S=?7w1%0@GfIi%g-$5S;LHd3NeYn>>>|a11 z?ltNA1@z%wlfGX-AMQ2j`vvsj{vdsSgFc)er0*}#hx?QC{RR4Pev-bQpbzIS()R=O z;r>nfetw(a*(}^;f|!mj>#GBXTD!iK?L}_uGi!0{;dPsj_DAj|bn+g_cUyKjMFSd> z`QPF7*xyo%L1Cj60AY|gG6%$F8PoDpQxZ$MPWs1X<}N?GS4QkP>K_j~(B=_LBlh(Q zNJzOy$0Hf`IuPSdKw_ft%1d=_;?Oq#q=a0$N5)}Nkz@hMaXTxo>@bL9!E6to!+2M< z{L%)xSPN6b&FG~cPz|Q4B&Eq@V#(xlhi1Y|*zT)btiG$GA8PU`%y8#oP{4~HDkMwF z{J?~zgf;hr^?gTugg>{W&jyrkybP`d%P>3fGAe;#>2k%aMY;NBj2eTt;jM9rc|mNL znh#zEb(+X9M!_x^yx1`s$`NjNmoqX zE-Ds1VKNH5JwEgQu=iGBQSW`c_RtJSO1D9$Ae|O1V$qFsgTM^k9YYSo(A^+8G|~+U z0z)fG3kXWb{=3$CR`0!^eO>##-jjU>CmjE|fA{zP+{{Py3KKXpCa5HS5!B=u_2yI0 z;nvSoZ)lQg@7uw0;7M|ElBS>;7d>={?X7dNK7o1S zYxIyhpO8|wlzEas33S9lz9xBzhg_rqo%fnJJbjnvy;R_^)x3%|ZB%q(`H7WJ>3N@#!DgOT4j=A<6gO zD;iMJpiQrd&c&n4`t~IZN4+rt&*d!+)Z1^-dqKPO{h#e2;uLL;{wfTN6o#+6`I%y( zY|cM0P>1$?8Qz0Mo`1ZtF3>RW_>ObDuGg0(%f_JzA{T{RsW0@EC&I5}T->)>=tZgh z2Yna88AtQE5Bu)cISJhPo@ddyt}tCRQA75f?5_7_lAp&UZ-lGOi-k?`-6p}>HrMb7 zZ#qLf$^J2!!&*Y=ZC+dckh!I{Llk}BHuG5B!mf;a)0X#ccZcU93Gs2q(ZX&Je~SRN z>~TD|3{4cHuax9X{*%}0ZTmE`QmboDCnHbv4l+U&)*kro=QAcOPG!k@OUhSzWC2KF z(@1)|4%qm7zVN-CU-2uW%dYGE!DUb7)rD99|9Q;y_TxK8()%u6=U=$>XBbAOy$^DI zJ=dQz9}l+6v=mleSZ}wT^&iU|$KGJP{6;Txwi>H&GC=~}sac}k@@?n;#D8Tl-hSDH zVwSyh$-cUn%elJLBm~j;RSWoG(+~oU>|oJ3NDN$+*}|$Q42R25^e$uPo=?4n+Z4A3un%n z!O0Pe&{2;lNWm#eiO{HuDAU5JIF5Km8d*()Q_G?ispc4on!yf5MY?50wiRKw(nNWS zMRn8Q^+KW?KvB$tTG#_KQMMgX;~dyyV$tzh(NhrInUv@hRP;P5e39iiI)f(W6HUZt ztr#GQy$*@kM8%ZM#O%yO?9s$liN!L17mGZE#MY<8o~A^e&&0MI$AUPcFvVd?EYY|d znD{IR7;hEqS}`W^35-)6MlOy`rX9y77e@oerf!V07BysIq79;>jo%Lry{&}L2#r5C zkLRMr=S0VA<0i!I#WM(^A;RD=9%zDxT!IADgaw@t-jN_T0J@J(c*JCtcuPEy0FkJS zhN+?xNly~h#p9ldC(&vrX=r0A&_fe+2a>c9NnxTdjKKJY;>j%INy5OfX=A)NePi;w 
z8Q{`1%$gP+A`kcWOBNJ|yICc8HYPYKqT#Rn;JlnEfyMAHHc%KMC89AU3Y`*x_VI1> z#S%dL3~fUZDcT5_mG|8gM8*jsS;;@d$~I96Q~*saLZp@;QWH5-<1{!C3EH4)ab&GF zvR)fm4)w1DBe5KjZ70ZT1oBWcHIEkA0ZpsrOly`)QI$vZpQPbrr7@zAE#UNi&U6;V zw1L&MDR9~hlyDfG){9PWJ4weZN|&D#&-iSGT(t_A-0+=m%-9x3425TGpppCHnON$X zyWq@CMCMzi%yWy(BXAaqGYdU{P@>7irAx<`2)Lm2yJnsB^(25ymMi-ToJHB>hdpR} zy(F7tGecH8m^n3*evsg!wmB~QwDD;0cS}Lt#!_g2tR&;T+zWC@lyg} zS?k76yPNr95lg{3*`-T)F)F%6S;pc~nig1eW>Z z;rZEF`MQJouTH_{kGS%1#|!LF3moYR9a9Sg*YoCZQz4rO!;&oHltLf6R5|OyIX6%c z9Vo~H6rxiUnpzarR1~9AS(mK}l-fKho8~M(?JZTY=qv9FFaMlXK1Wclxmlj0 zP`=Dou~=NODu*z!s#vowry!{O!cuu%uJXrV{v$4u{nO{?5gI$E$j>E27%GXQbhK?8pk!)2bEHDsss;hXgrgo5VES>dUjG^pUxYbHvQe&uJxJ5_ux=kVNcr zPl6k&Pr+3QC+dFeHF!GJ+k4gA`4BN=8WlIORCATom_)TqNzIl(?XArk3f;1M^u)@| zP+3VNZzPcvy~p#p{JWAM&eAuRI(5@~Z%W{GS9B!cIgeL_^;21ud{yxJCxms43J7P3 zhSzfi){@2dK&3tr4f+T@6*Y&7IYYtd8{>gsL?Tiz z-=i^Cr}4>h1Ei!;t)nq(qw%qNDN3iQ5ZaKlg=rsI%UmppT{_pKyw_9}iCrz(EU(^N zuY!$2Hs7ymZXv{OV>xSINylvab8U*Q?d z5cVE&_a3P9?%DKyjqKel?WJ5B>ixXcyF}PG%iTAv(l=?-_b#$eZnU&-XsEA$tFMQ< zfglTnE6^XT+8>hEAJ)kI# z#e=VO!`0{~sZoN@ zdZVjeqw8s-n}jW++w-G4=c9WJV+8wBW8d}04!y>X3Dd?-TgJ}k$1cyu2tbVEn9}3e z`s27@?{WO}ae~(IYj_Le#QI@m|FPJa|3FSb08aX~fc9tJ1eoi9!v3qa4(RNF*A9s6 zfZz@&?10w}`0Iew4w&wM%MKXrfT|Ao>wvZn7T#EC**=PHOOR@ep`)nyU{YrSS@wb7tf8S?g`hB47U-#J<{upTcCzqoC^FZ6*xD@?g2HO70rD*>$ z(Dr98MfyW0C;4Ng?H4RX`sbClpZ+=NUsu{lqsaawp#O@cNd6?C z|HM)xz)sum0y==D$o?vz16YdW&jLDtrAYoPpaWQnQuF7+QQMf@w5`VW>O{*_Dp2TKwE!li!4Qba$w)bCh| z=oc>aA1p=ulS`=SR04{ZqaXd^^D04J0=5u`7n<9BpiKkJY*mAJ&ixm}=-^a(_&0%VN{1Dvn zKDp_8zJNEU@^c+zQt$TyG1d-MD+zk!>i0t41#!+NrYq00E?*Qq@P7O>I^uaATb;0q znYC*A>EwrpFG~c2EngbDyez|gS*q*$ShKa{c~wY~u;~x3mv_djQ)^9J%WYYsb@K_P zQ4!`9s^z8n_X0I@H(e_;Hn`QV2((7=4~xus%k*w2Pa(;LG~OV@ZPl0w(L;*$#o^^V zhCsJB<(USs&{}3HL^wMd=2nu&`_)`+3q4sPQl0j-OhxwI*=%2xNGWEKoslWGZt{DL zhN#A`w)%wgN&S!qD<3F2i?$_#6{@83|{cW^u zf%E~e*%s8Rd1q$C!G1TA=E4Ntei2-xgr~X;li{T9culsWc@$~X$uZN=Z6{#Jq^E(a zA|^l~vlqM+{}JDThM%#1FZ7I@} z$KY~bvy*4F0@gUJnc-|&dOK>h_+f3MV~mh9dVaO&_^8QMZ797EX;B+?1PJH$Ki%{D 
z*XU2j!_;bXs7Kc2C1p~5$TfhkOU2=Wuz^S%1o+k!DzuV)^VL#Jb9aEo@!HvPue%j6I5(}G)^*Yje zn7F`4&UHV(ig!c~@E|DW4As46ZAO2euFu1_iu@|N?Dmpbys%5D+`1U~b{pR7;t-4r zQq-1HvU;_)9l#dvY=yPeIbgWwqe>gjbFjGZ?6OOMXiD8+)7 zd@@kiRup_R$6f2(BLE^r9gJhS9x({=sOPZM>Iki+A$R9QMY7!SD@%51n;{ES^y&&D zd#32vw@#*F88u`{#&8@(Z|on(5lw$5l6@+LtPCGi4v(gO7rAIjCUNZfQH)GZD~9rE z(7F~GOw4Obj0|2Bb#s@4YJ8%IXM)))kdw_K1HUxIf zDTV++x;X;_zX^)Rj|0U;mCMKBzm28QCe=s9VLN$a-QI}9_!fJUmedg+Nj^>DnPdEC#l#DJ^ zCTWr*f*33g-|m39A&5y3@Qv}9!+!X3M}ol(aLQ-(*r0G?Gib_+UvdN|u?;$9nFR(@ zBE}U*d>r>wxMhWy_k*XRi9FGW1(vw14I-*=uiQqWC}`>oETsfWGz~}GC|4pbM)b=n7(RMu{BB1UN1IBBwJQQ9ahpaA$kt}42UGrgbG zUBxPW$Ot*Rajl6BojyG7sR7RD1BDYRW(=^TtQTJkY0T&sNZf&5n@dUfh9-OkPiOu? zdo3Q3*{zN^FV3Vp%IxYuV9pX|YG<`B!|_iEH*Yj%wT)*G4-%HqW;ei6DXa-O6~m|_ z2wfYqQOlW3I)o`Kv^glbtlM;ipP)JQMrmAJg!$;4T2P_@9pRpMZk1g6O%Xc68bog8 zI8v&F!1*M%NG@rGGp`VoCNGgU2G1)cP4>Mpke4NptU?UqDAnp~2+WkF^~j`b{ zFUax0&%Uiw62;;b8&Tp{Ra91jZ>*GANr10lky>MokMt-F@+)pS#q+&2TN+Z8+o6mv zBT?qdQq)H?LD^t6q3~ird&!X;)Rqe948^PFLliUV*KP z+v8DX=~qP1j9W2VWxZTP+Kd~2L*k94dNGv^ZXx`Q6)2t|5_h*Sn^_XKP@>wjs)DT< zhxG=s+9<1pha30%X|*9~sUSVBicXDTRrE-BjSj3_N)iX_)>(}XOWA!S4p_JLxlxtk z7WPn6t=4i`>1M4mOSvL@ou*MavxY8?46;tiu~M%Td+D@Jg{8ve44Xx#Ue&L{vKgC$ zrMdp$a)o^)_BXl)1=5<=k=O>F4RTqP9xB)uvkmu2>u>m-VFgQ|9%a>rAh8sXD5;LB zC_-$KGgMrILp(jUlvOE5(h$~kAqIjFTfrS8{S1qJpfP(6Yq!ude+a8pp{ckO%VMrc zST0GAf}mMw*@bR7tX2|>3f9~**W7m2Owd8!(ka;j5UE~5&z62<3xK4C=UNCx&sqQ` z^#T1{#&uPdY^a?796$r@#!IH>Emn>vggbdUbN7b#k?I0u?j@zVl82hAu&=uDf8p zE)lOTv9vA;!j`Vz2o*y&!F{RjhkD)eUfqh|v~D1TraIrPMtI(>&d~Ers^^7X4;VI1gjz;0fB zfNu`e&Vfo`Ul4e{*~hZ_7I>B0qh2lh5!%)AQ%AC0PO*Q z+TfB>U|?Ybv3Z~_@GGeS@&f>=afd*8dHIr(;lG9o0MG+q9RTzIlm~!2ztSH7{QyV^ zNEG~x+yktKk@L?M(*Kp8ff_;l|JS0@-+9pg4hKE>yQuVc9`s+}Aklx}LH`X7;`@UK z{Wmy>^MB?k{Xg1wjA>5!FucsMra=dr_&{ zF{Sql33C5>Fr%thUi(g)ok)20MQQ8Z>)NGh4<~IUi8_Aslhz&|E-7O3f-rF?1Rx(n z>D5r5!l-XR=Ij^#)FjE)LE;qlx?}y8s}wq#;Gwd|s1y5Y`Pc9Yz&sGD^W7DhjuLoFNDMH)dGA+wcD&Lk$!a=06|ZL_rd zO04xZom14b%P0DG$9II;{G)vsR+zrgC=4&e!RUa|raFk6&?YRFh43>Oi1QJcI6nL( 
z_owySs7H^i?ZXth7HMxZp!<|F_MACCGZTJHQGB8nlMqaCiw$kruMl5t8h;m)wcK*x zrC@bZo!)1D{q;f3_iw^uygu=B|HFUVGP(d(ukL=JaX&`@02`WvdV75R(sf?wUwn zY)}T4jOSzSr3sj=D+)EJ`$dp~p0rxtyMm0%nS3kI)&yxeXIOr(St|kPaBmqL(xga=J{L-lIRy-8DEw)VQtWeU1vbOP=fICt8k)E4i=>vj}L#;Z3b;Qxk%`L;TpDmRQhSWejSvR z~vFC9YK1~{kfYgLYH>P;V6M65R*h|VvR22E^ zJLFup#(gFzQ}+6DK596M>EO68*)$F29fe{B#)Iy$?Zcv5KHr=`^+K9?-Xv5=ym-@%bKDLshS=J8Kh~0ncmjGW#~5Rl^hG z6gWUGk93eG zyn~Ex*K15`p0aSig;!=Tnr`~zGL>8RwD7leHocE{4&Mgu+1bb1P=7w4I_aNy>Fl{g zvcw=N-nX%^?>${V8E1AhoYD2tWps0{aM?h-hpf#dpTD(-Xl$j)N%FfBW%}}C&8AV| zwoe&CdPNt94pyc16{WGrEq&Rh!I~Jix}*4|x%AWa0b;%y3x*B)vBnX@Fqalj#!aKx z=Jp;%_nh72HG4zI(nZE`?ooR3yXdBv;JKbo-=Q5-aQe0bj@{!0YW|m%zDzyYp=axjDv&uB!XIt!t(y&_ z$0JwLhwI6w%iIoLb~g;a4LY>#2vvH&kZrHYfy;d5)VvsCHrVoXm-zHNhwA zx+T)kPYG{f^y-a@s_}MlG5Upfw;(_{J+B8$}}Gl^#`no>M0ckT}OFs z^ttJ1`OD7uW`zba()cMtg5OsMc})jB77uT!q3E^&WPF#mrcACmAMNKAH}(3*rcf?X~h| zHy)J_<`l6;uK8V)k0_ECp25{XLPFBR%$&p`iWD6RAQ6PC;T3oh`7@F|QxWe4#qUqs3zMfLUznu}AnC5H0M5ixsZ6 z)iH`2Nfn)m61i6F1sF!q!pMo<$dY7O(|+jTZC5H0+s&ysI!$@OqgT}7aj2ZQCR0PI zg7~RJgZD|&h?z(>C^rXK>+nN>XF^z&nSQlv0*R%m@lnE+a6;9y1XDacMPVNaT3vkHz91>X+SfFwl}B^EsLniNecf0p$5NDD|b$Zf=1TiIHk zB;DUgo}_d(Q%tTfN!H&;a6;I*G$v@H6KosdWyi^PG~sRoToCclpn=54#c;jilt?8V z4KzGYG%OUs6^=-N;U~wPz#pUGXjbhmcK054L`git2J8|gj)*}&@&JeS9jEYFrWOXI z=6p%=K_@$-Q?>F_Udv4};St!V$eEPEo<>qs8eeCrBtK#JMQ#BWDIbG|fIKtynwbJ10D=I0L4H zFcC-Y(sF$mfVU`R=x(G=qti*)ldhDwFlN(EHW1O!^rHbd$spn_I+GNYmHR$x1DsvP znT@-N?8#4h8vcwfL^*v$DW^X?%P~9yFl2*^v)Bo^7*f-(b7is2BENuh-)&?HHE{}i zWQcBNVR`7^w9Yy+W0fh;l&Y_ty$a1%-po~v;Jg=seAbkqpzJxfp&=EK z&61R#UY-9!C*R5=*O*Sv)FYqPy5KQgCN_cUgW2q1bb%7I&=_23##NY;8c^`)G!LgF z_s#9hH)46u35xtT^F$E2u1z^19*@K5kdpmH(aMEf5}8j26EaUylQx0zYtxie&H#d9 zB%PjZM4>fZVnkDBeo6kZaOvsG(zP$~aLdx#&G@F-(i7I4+SJl-JsAOwMQ?OC2{(+e z!SW47WkmIN&n+W+v8=8*$|p+h-UgfXNfd?`MA^Go*K$M@h(vx&ExImW!PF>NiYk{& zeP+|JTA?&$vpZ|^ZPR97$L5Db<)L*YTE}#IlS@0=iZN2daNE9!<))dy77#Bn9@ zSrsY$8$c(b(0xNC=J|#O`G&6f4a3|Urn5IJx=e>OLO;Z+#csv2c~&dIs*h-@-Nj)E 
ztM>(L426bDg>1y}5H%^Q)fEBpWzCiqYZ6A}C*|v@6^%4;vMRQz{WbpNcUk34ZsFFk ziPSwKtdoyaeuAufT3Yvl-b7OpTTi#nxU|ktw=4o($udyu8*j^lg~7*Ju|Fkf&z<9l ztg_pR=LNfLmDJlHgYi+p9AQ2N{XuppKc>6E;U2|aNOsF2p8(H*2ge?dXo^3Tpdvll zqlbchB~eydC^Z^ihtu?4k4D${Mt>E5*c^~e@e|j|DSFftkiXS%2KU7iYsSS7F7-r} z+e9Q2HoKTK#mxm5mwH|HHHO&~hio)EW45p!H`Ycrd*XX1B2fdSEsmWn)+x=m6k!|u zuF{^d6&iDzoy*x&MbK$`vy-l@)fT@MlP7pdrIl>B)8bR7d)V6s`i=x-3F&;BnABUh zb#J!$#!Hn>!Sk-5&8GXJ-DWvm^4GcyQQZ!xw%5=ub$=AoJ?z&qDrTE? zRoV2Ti5m5=*rH_hAZg?2ew&(ZHFt=y)4C1OlUcueEEj$C|USu z399xZ&rsBAd4`+F1JAMh$T42r3MZSmS&1=Y!6hVVE2wLlI z^^TKE*HJ`|hg=t>X?;g*XJTp7i#fuES87Hs{@%{=J$|G+-cS{fUGxp|i}wPI6N1tc zcl9SkyeGucCnU&QC!`i8WG*J;7$@&bPd?P2lqdI|R7{^#Zk<$Jm{hx%RA>C~j9mJ| z3;hoo-XFBmKj^f6&|COmz;^M$h;hn9dg_%lD_9-FjFFEjH+C5JofaJUSkuyH)gU^t{kO>%52Q{OHBJ*TnpU9WNzl0eWj; zI(lI?T9j&hVR|29$y@X~%i`4i#h&QJ4>^mTE{j6j(!4v23*UH`ydTW|xZq*#5XF|8 zIr++S)+)+qguW_Y^8C7l)5e3*hQ`iV9(q3$5WRed*oTMI0DZlU2f_G}x`Kz+{v&-1 zPkAev{PJTb`6p->`nC@bTZ||j48&>A!|fwlD|GRxi+lwnuyU7)MLB%g9+ka^d_%dAnMOO{9_4P}0 z=a=uit4)GSb9Sp128$rjlC=*vp~31M2YYU(7|{#qHOJAV?#ruyjY8w?MD+w_hr8yfRqZ`F0E4f6QWj8lp zs%_>~uN1~?=HzS&l`(C``K**)Zf1U(lULg+=)|xcT|lXEJ#Sm6S=>VIZ!I%!cV=*Q z+b;tBdied-j@E77VFOOf7*TtsubF~hSM9!L2Yj`w_?p%E^~26rB<9Yi9nL^v&@F?V zWWjapy$#+?AK>@(-3rccJ39wzoZ%TeN5q`bKA#SOf<4m~l6WuPWaCn8FUe#NBDBYj zf5nlpxR;Q#;Ys{0`sOBs+_$(-n>4$7Z$?+>D>>+{zD4b?3=@6jyLrXVkcH9UvmX(^ zZD+sFGse#6uyDthgKun0*nwlzU|T$vgP>|lVClfU^Xtdkyx&9a@5mc+C}w^Stp2Xb z%)ty=-M#q!*mvm#KgV{(u4XLzkB4nP{K$88c7M3WuV3*Uy1w5fWInv_hhgn-XdG}D zZ*a(}juF3eXqU4WpRt2t+;NL#BQeBw+WF)?#+JIXgvs&)BnOfR+-1vAn+3x}?b?rx zZ*B%$&E4Za(a+h$2Q5(hoj6plILUm=$Yc|#5UmzDwKv&{&&4PrVQ=CSr&e4#wSI3# z0|POTZn?VvuN7>R>YsCKFq#e7FgsQgT&CZWu;Ke1Yk8b&H@#l>-Bf38GhF54`|ah{@3+-pe|tWkG#P3yDwfR{v$6ibx0+u=??!(UO*d)Xhi=x zFYu>&{rowDw7z$`+#mV+OD0)E-^}(u^7Rl>Utk0icoQI`X-Nk9 z^}H{zjl0woJ68bxKRa6Gd+^?vpRzNTaa>VxR|PlbVg{U?7{oaUG3G|sOEw%G95 zVD8T5Bgl{Bd*`&Tf5{8{lCS?zs@)(@)KOyuq9$;uhj`PwKSu-o`d|A3Kl}A zLkQ7b=r4VNVoi3L4yfDLIbEHfeSzb>lt#FyZ`&W2`J0_KyPK_Iy#BvrO&Gx>t=3tU=rW^ 
znXhNQ?#GM$@@Kw2_JjmUg5ROi%iucylp4e9^Ur)e&tM;O=+At;1y?mtnfD@Bj!Ur> z9OM!Q^7T2Ki0)tV^$J5kEh&+~@DcsY3nb-CJO|?S zFBPBLQSwY_346Z;>h&*QPwKiZd>GS%UQlX@coQYhm^3p?&Adu_89!x~?(L@VDm$Gg zdG5+plAj7(EI*@buELXa*PF%-ayNr-XUy_XDPz!!0=u+&H zjFSks1CvKcXl#ZDG_&#+)0t1#GprCpgETNI z)d*L-sDRLoG^q8|2ww?ONR~_*Y&UB}z^2__cDIf+Bt^&gniWm?^(oTOC#lB7(6AiF zL((wDCS%g@qFiPwcCv6Voe4R@vYb_fEP~3zgrc}V>y`#tqZt8e_}$z zHV~F_Cz&kTTilcm%>n1ECyPnYHf7jYPZFIXiw!}TGM(TjNglG3!M>lGvVg6UrK!l{ zu;^diq~%O{AVMB5q5bMMCn8ZngFN95;uRa4I5<(ojXaS9{fa{gn(!o)sRFf)y}0Q&AS%20nj=n2pR7p)4^K0~SBHCW|#FOC2EA8jU)^*Vz zsO{@vbnr=PHoUw)-G0Su>5aoM0-~nL!j=~hB57In;$ob6+hgk(v;j6Iz;fxNmUjpd zZZRNs7i(k2brR!clrqtYnCYJ9af;9lH=)vW>EV~pk2JeBLNhIvc$cx)NzkFiRPAu4 zFA3Wj#wq^l=7VQ_3Oo66NooDJ8AWN~sv&|mGNry~DWeCy%bb(7!_D|s-VGVXn5HoH zTvHH)BcgTXg*e7qRs~DYBM4iUw72*c64dBX26uVqj33*tenyg9aPLxHNv+Ef%#J_i zd7X0-#%L*j%jsQgMQ$!DW4EH_>@d9db*gadhN%1O#H&2U*F`q?*3aCXCJRHdi(f|f zyl9XJE(>v`FqF0-A7q@JQY~|>kVM%Sn2t}kR~S_VGv3pqrVs0u7r}~UY~5zynnUB* zy470Y+gW2d&Q9BB*VU!V+Pfotmbh`<&kKUmo#2T|t1|A26XiRu&C?24jz`Ub^bXJo z`nfr85#Vzo72E4(vV?=cKWzi ze38s4)|cV)4FIz%88|0oGJN4A)jS#0UUY`H1KY*$lVK|ZG8g0$!tzNvBVz8W&wFeHpinsn2Ex z3`pH-G8w;GX`aoZ7u_1#8FxsW;Ait22JX#EjJwe3vqjRAMfdhA#yx-P^JTChq>Gt^ z>07ww`6t>XNS`dzejJ?S{4=K^bjXnD05N^OD!xPt9rI=So=ttRu4L#jk;(Lfs95u2 zQ+vr{x}E9p4V>g++se>$Zi(rrar$Bhx3+48l?e;lV(JB;z0wQjF{B7}PKf`f&_q5>m`t>X_^> zOzv?^eo`zU5_K$57?$KXmMkgu19fZ#Vi>l{IQA1#oagE|nlKy^-Eka4Qe0DYTniYk z%{VTJ11X-1I-WZW&vP8lmlU5UP#r%Mh95bOA4>{O0IP%HFtF)uAZh#SM8xfXNEv{U zG3!1Cs2MZ%!@F}SSL>PQ%Q+NxBxc&G0P_N{ zHUNDCusHzv0WbBVf^in;Ks}SKPe;d z|4l^vzln&yr(^yPqzucyN*QaE9epgnlw(5YzYSo4e@Yn*+xtU2+)ww1??$#f91$Dy zav7CEmUaK;WKi?o9aMadaqwQXLA3*LG8Vp1s_@_Y@qt%2`p1+C#qw|0IYw9Pb|Mtq z^jW`L=eW&1c*rY?p=vv_KM%}53NJp}dMv!8p1^y&Y@BHaS>R&M`sFqUvf_j#(><4{ zS>g6Mp3wMY)sYb6cp|}tPh!=}H@ah^sIcO6vzU4B$H%H$Z}wLJALDp1bIks1Cn_Q4 zh_nUI*KgMai~oG5d+6fFw`O1^>fpipsMqeh(zND5dDV8YooT`>>96xxd>2Pip@vP< zpCxyszK_LLN|UXQwK^XTYw~xjVGkQVyf|E`Z2#65nE6AdLFPtYrQfz>+F1d6$5}Xu 
z=h+G~_!GMkP|SIHYeuhp~PJTUm#sDIGPuq=6LM^mvbGC|gMEVLqKUSyw!k zaun8kUN}ZK4ism!I`h)U1Me%|n7qX6puznX0=Z>E6b*TZ*R~fR^7do)9m~h`#JLN4eWLyv*3CCX7H9qX^UW)W0w(^xXAnW8wDPF@Pn4dyT}KSmu1SwI_8QkarElMT_31s{m`6YY;zM4(U?EzC&&LJP_jzqz zOT1MT_qEB+YrT61#bfTucQXiC*J%4JM}yQRGXsZLAMqwX?f)p`yN`WKTHIsYJC!1j z@x!_bZ}>CSlGphZHk?0XlqX)_e35n|pY55;=Cd5!io$Yi4t0LE$sRk3*gIx<51a;H zl!5UX0e*E$;+4XFO5@QAnZS6$=7W>v*DUX$VBJNFl=exkx<_+7lt2r z**v|m`LxjFC4#m5w!w!+&62|(p1va`3~S1o>2%5E`ad?!sJW#gMcj<77#CGAwlrh{ z$g@R*xwW4cYZWEGtjzS*vb;5?QJyMNlhewhm1M1%5%VMS4bi6QlBmu=NPM*uudOv6 za%Rz41d>6t^VVEQW@X>L`iNGIUPwtqD~_<0C#}-sbtP9QX3J zNBk=}+y(*VJQ&9#l|Dy9OBaiGnx}<#q1ZzIpwmU!HFtGr7fY>_&yP1eHLd)9@z_74cq;7(@~ldW8@Ewa>hZ zc^T&QiVcYvke4}#|5mt)lh-(6&m@?J_F5I~LXTQd31QL5`AHu^$NY$Ev+zbNWCF#9 zP1Ih0OUzoAe-QmH#{N6`P2L+D4|Gon5;gAT2}QRnxin3{iR+3K9&bKFo=z62)s<+D zWNvCSpMI#4xj$yk2tvF|wUBHU9(69aN#!-}?!a{``xgAEfHQkWb%L(?OHYUA^zI9H?f%a`K-*~PuWU1W>#7wfUm5`)0$ zy3O5%4HxKCab9ZyLUd{E938Ox4-guv9QR8Lc-S-gimMqE9ROHG&*b?45njG znE2Wy50+k0xAcalx^(T%H&%>ruP_rIcF8VopYcd8zC`}WcUW+ZT=E*1^gU@GlHPmG z>}Ag#JM&KV zj651@hd%UQ@>n~Qd7HxT^P~I;-!}J)vuwE3XE3ujdMsmY%vWy?HbBCAAp80vXsvC( zX4mJKZPD#O!FSZe*M!>lq|j$c0jg!;>Hf$&J;?*pk{7#32TTKfW8tT00X<{aD5Uly zL-AY})1oZWKJC-l6!lsU@{L~cXHMa{v2O5QGhlqu2lv?L~e6 zHE)i7|D8TZwZgy;oq>K0-uV4~(qcjEtUe@=phQJ6UgcvCRfzvF+`o?6t2;ChXu-U5 z4wg&t(Gd#}DGF9kaW~=!nZF;he&qRhJ@{NR@N&&>JSot=pU06V&}GKEb1mp<%>_8M z8GIjlp$Rd=_p#;(gY>%vTZZ*J8!{&wKhq3P#y3r|G^VK!zYxZ}L1mlM9}s`c-NR;D z(P!IBW#*M9`ylbp7|hJ8P@`scbMppp2Ls1XO3aNYVM(7>0^%JD-YUC ziP@To37)ceT@$srZ-o(xeW>NA0ER1S#=Qg}+3YA1;_ z3NRr1Vvhwl21(l5NlzgTMresVRAR1)vyOPIWtch!+}R>L*>)od3U@YDO3pTMe%9yA zz2?MG2>-|qk5hyf2ZXs5g*j;%5h{j0)C!cxPieZDBBPk%iiZdk_kNZVYV6aWQuZmu z>Lf6@*pGcZED0PA4@Z>bAXaY&Y91lNXoKw(Q|nDq{l9oy_6NVWFbY(ZblFHv#YE=+ zfU2duW($MF7C~~0kp-Q|`Y*^*Z9inOTNl{R*DB2w;q|sStrhLrUhMzD%HvVL(L2uc z%z!iuHY6Gy>~|czFyPt)0cYgg%y^-YF{I>?g-+c*;o1>*--Bkj9;YNKW&S{O9ihFZ z8vTzNLY}%H_m9%A8j+LYfzPH9R4JL*gK4-Po|lbT`RJ@-L1>0!DT2AqVtF?S0u!u|V@hF4C3(r6dC9@4V(uwogQ@%_SsGlNTC*>7=z{!@^S=Ga 
ze@R!6AXq>qk;T)LIfBl#9xUi9&eo-Kk4(vDrYl^7hrKRIQM4}1|5U)VnWwLu&Bs-= z*`!pIXHs-uCzqZpDAFS@db8+jc%eC0p^Hwzf@>~|L{5}WPD*OAjD(|5Xz*QvLXp`p z%ZRK^tb+U!jznPa#5EiUzNDs<(rdW3%$717#dqN5)^Kq&DMsF>jq24h?beDK6NzeB zFB|iSQh_HbYnNp>myORpY85k^oGl|;EnlK5|HxIL_Pl&bM{H)6bL=<@Vs4HuGT+s) za4s-^Enj(mHR@n9+Ur*34-bnguIS7a3v3ljTuIA)epE$GC+ul>?6t_`v*Rkl0T_WN z4+-HLvY{#*1nm0I8%n^*ps&6uS$$i#8W1ubaUiR?nyY!{s`<{U%c|p9O>G2igzoCr zu;j)IpV@St#tR@h?jaI{pb7Kxwdg6~4UckJ8;*Olk&iZN6~J|)A&E-drI9LUY^%LV zHZIB8xpiTW$si(yZo1tNJCF+Bj&*)OS5vV}S;*Af5<;umC68Wfs zhU7@Llp#c%QeeDIo{MBrHX%nYH-bj7>1jk$hizSx@B5~L=HM#NCQ+@z+Aa3_t>UQu z#&D0U3}jZ=Tyst{>TP8Dn}L!9kAg1ll4RwQ{vr1E(#Em5oOG4OacIj2UAE~nl#fnU zQfd?Yv=vI;7N1&JLDx#u-})k=%_FC6bf~S9u+d7Q#LkMlUHg9f@K)Ost+tw>cDGOM zY07zi5{=Z#9nK~lzFeNQh>mJPgKK&ngElE1ba@EwHY8!|F<~h;eiLOGE3ubv4Wikl zq@zW$lgR6B1{&o@&_Xlc#;n)!X0w&2g_U(UjeWa`Q|j$wCwLdHR|y{cf3WvfQEe~k z-fjYI2@u?ZYjG%2+-WIR3KS`>Z7FWWEx3DecXxMpcehg9T?=12JJ(!u?zP7_=X_^h z?#tZdA{iqYnT}scjeCr{;U>4kBu3C}^R$?^ML>qTueTXGlE0 ztK$ypQcmv*4{Q+BNIPn1)k|!A{Z`{mVH@pYM@?Y&4PVn;L-$+JZUNdJ4Ydw!lScdf z8fLHy+HUP$L!OmX4_RYIx_csZ)olS;c1~ELb`p1A6u2*KwnEIhCpd^|p0{{0zi=}^ zaVb9_m8gG%w`fbsS!tjz$%HBe=6X@;&mJuN5#L`4)=w|&uVz#(a_%Z=q>_B|xkF>{ z<0Vgb;$WxqU~l7K|NWrz^x$aWptX}}b8Vkp+}-@Gn=YF#j4UYEbLE-K8- zT$eFO(ilP07!l^u7|G8uGU9OxfpIF$acY-wTFRtx`lfNlrEz9rLaV#+v-Su!P1E@| z#3*tTA6P$)kplo>6H2_AIKzy@=l}o~`lPaD=m0qoFgi&Fos=vJ?PCSXpaAdyQyXsL4-jl0Bn}fav3;covAMZ3rcHq zxoJR>uo=9qxHV4Y02Kiz^*}@<(?l3HcLA*3xG@mEGH0`nkxFMKtT3R%XDm+U!Mbv;xw6DWaGkVr%e0&_w({_EWh)84 z+huV_29AJb)d*uH7-Kdkcoo3%gwz7)IkftPb;;jy0dPE7EU<=THUlVK)q*6ip@u-* zVV@xLf-B0*^W{|P>@uE+*)&$`52bfMSeNmLmZwPgmp7~|HyqODlw2l1YT(ts0vDzo?9j5JX-VYex9}b}% zB{}e6JD7Y&JMDVl5pgh&O1nt5eBgSzyZW$yDR5Yzb+~!ByPbSk*?hR~x_9_+SWj|9 zdb+%Ku65MvdUP$gf7^U?!g%;VatI(jc0N6Lf=2U<+U?jZ;uvX!8oY9Bbb5>)N{uOW z^7;Jcc)6s4z(E2qqlr_8dET_LCPc$`b-oeNExoojqQ7eKt2VZQiT zz{Z7Ub`)}q-KCt~WwqbsPBBb1aQS`nav$|disEYg&6T9) z$ho%1)mZkG{>YUA+U2Vkm*_%EteU4@Dd#>yC&|RsD&McPo?Sl(&R9iWz?EL`#hr+y 
zPCZRNjU&BG7`aI*2~E+y417EfmXN)*+lvYyz4aElEi`|UYI|#(a;wyGTOs=-{PLo7 z)j7lQI^L&B}4ApckT z;S%qe<7Bq52%`S)h>U*%8R``iXaA!_#@~PpoBt#-(lP;m5*Zy{SaiyN5E=4t+_vY} z7#}4<34Re7zkm$A*VSKdEI<7QGV)3?KuLTsAY-m#UFLW1;iYO1+#f*3q@i4Y0mVSD2npaPy&zLWn$myr|LMNm_L9F-ZU|{=MqH(wkC5GhC;eb z^OmMdFd&24`OeC0XgXzMIEuI^R zi{nLW?~035(;q~JZq=1hva9RQyBn_Z7uQV>e*_=yu0t8_!7w7@8Sx}sTQ{6DjL2YW zl=BCryhraq6}Ug|^v%MP>HEk-47(L_^f1E!v;MfQ*4CL5tHKRPNSvi5aWGh(WfYM$i+a z+9#==J31g2ve=+-Q;vEdgn3NME{smND<5*=63E&rx0#@eOWr);wnHKX-A z_wcLsV1sk&Px(zuZ|-w=E{>Y0c~_OVe$nsRgYqnA#!uWumh3ELakC*9j`4NE<>jL7 z(vsfdE4xbQf75!H)kcAAjLRc0YsTs_e3z-z)f3*ESKw;ZrX} z`e>8YAJL)o_?xkxlH)exNsGrfROo8~YiQWR_}d94f0Q1I@wZY{4afVC!;}bi<3f_- zVK7ECmys&S!ph@A#RG3vK^IAaNqAfIh9@$w&eeW(mdorx?M5;oM(|EEjKw%BRX5k8 zxFv)=l)NQLj+E{7nV)oMqY{ttu)0U*cdCQd~I})2CSmZ(aUi12Uuv^&mdX#o^%a1rW zlG}^K(B#|QiUh{x+x;pb*h10t*j#i!Sy*^A0M0txYH9rm#D_n9)~6;~1gF zmFv={v*Is}imlxj{lU~fRU~NAN5lNV6cjjfy+mx|;o*wpRL@oW(BF?o#7&Y>bIkRD ztZ2p~(_WL&imCPkBF4Yvb&}F6&GlcEjYpNeCT0AjI&d^L9$i;S!c1*5H?Vaw9@F-c zgvDERa2aDFwy%44e36ma0s@Jk3P>92F5~)tpSDTpT; zMB)yZfOl$cII?Ua>Et7XpJ_vN#Nq(a=jI%X{9tax^3d}A12*k8fMC>}q|)+@A}vOg zx16sab~;AX_G>xYT9IQjBXyNY>g zN(EXt1*%iK%6N72N+qs1g}OL!E2bCI<&1@&^&U*Yh1%QvhP#Z zI>SoFYuZsda-3Vxx@yL2K8ZSV-oaGYzFWp?xtTh0y;;zHe86i3V&NV`5jAw+Ng-{B z%Eun~un|}Y6Vi?mee6Z2p^K~y>A;^p_Tj-?)J1cHbduwq_(^K$VWmL2=rEN}0#q0E z@LC|lcc zh5_FS{)of$S@O-IfzTuVs5=_&c`Bmj7g#iU%wPFD9Uo)qi?|TMczE=ACY`3Cv^K#6 zdffDRHqVlwtQ)~(8tz4|B$}p?LJGlDp7KS$>XOli7BqtCvgnIKLrr7#6@r<%X^e|v zhb3d}M}pZl+{;pbO%r`IQo^}D<;(K8B@+W!>vs$-`m!=l)6`g-aA9uxvbqj!$<*A9 zaB&Uysc7d1Utk96-^q-xIjH|6 zGr+2UBQsF{Kr{X#Gf@6SGk%j9D1V_DzsL;KzoQu>C*Z%M8Q_?|p&6*Z*o=Rp8Q|Y+ z#&0wO6I{!5M$D87P0k8Nb;K@W0@UUu*{W z-*Cpw3IYty(5~|M7a&3YPe1}jN5Ciu7##uAkiw7#7}x+~9$@eT410j_4>0lphCTdN zQ^05k7#9JfBVY&w41|DD6L9FbFo^{mCO*t#0rOJ81QjsvD9k|$laazK7cfo5ZwKjr zp&I@NU&5>vVgGBb_+LQce*uaAAAkhTKLH85{gYpS1lE7Fby#|^|7%<4S6T<_pKYB# z(mGiG#bNqOT!Hxyhv^@X1m-^+rhh0*G5+o_{X=1j`45Nb9|}{9KifM0PU~Ra{w5{< zmexU^F}pjc{aacG{l)!p=ik#h7{A&&|48eg|LHLOOJR!k+hO{b!W4bx4~OZ$6{Z-e 
zfL{*NzbQ=7|4UkD5aXX|9avch{hw(amLQCOq;-CkbI%q?! zkaS3;z_&j4)_$D*?O9_^>?uv2tjT;u%F;cJa82ikyH#H5v!Zc;-{`P6wPM;TOjAKh z@Z8+UI^j9?r}vx&{$oMR{_!Q}+2U!H(z+pvQvI*DB_qF$N4=AbV;0MkB-E!@2p|8J zS>L!L&6lqboL`y4h(jwE^%|N+fO?TDA(Q_;rhhV<_aaXYH8@`}M`|ix?IK@gC|?<+ zrhlr~@uENnr9h=e3U#_X{-W@UOo7@=|8#Xd%0-byaDm3w@2(EYA7Pz8yE-VSg*ph* zv+V?zrG7GndU$jUz2A7JV}oo)^l6MIdb54YW0P_}y-1krGqV*)y4W}1A)Jp6!4=4| zq1iW-Pf&|3p5`yIIWX?SQtus(uCBV+Fttor=s%q1YZ}@y@4(Vnp+L3c>)P0{EQwr9 zVUFP)!r8UX!O|Rh$Hh0*#B5hypm|_HT+=Xn!)#wfxqKqZRloU(*>OTW>9ib=_t2%p z`a3a$_U33!bAR0n*YJ#$p9DEA56@Ad@QjH-EC}~ARKZdYyusCHeAjL0uuup#&D%93 z_1Z6p&Bw-Ug=^CmwH*aF#4LUal;1H@U=mYi6qv*`9;br)C9~^^fga=f@F;O@cS0R& z2z$mlv37QkU@!wr7v4LR>edxXf|1EJv?ETa147h+8nQNhp|ODh6=sK- z-q0^Xf%p`8AytLa#sI)9FRSd>tudc~nv&`)Rsm~@aQL8y^KV;cP`T~k*q0cd zT@wKImC2eMF3i+lLEZ1Br{E8nfY6v*4bCs#hoFFKpB+nhp#fqnp;5ugcU54nehi%Yn$nO(le$)-Q0% zfMI|(4m819$D^tu9LZBnpew!gPQEaw95x0KQrTuK<$N^&yFAm?L6RO<2~g^p)B0c7kk$$%_@C9 zEkBn~nQqmuJWuiHZt!fVu%Th!-JXAhhNRn2XF`L^?cHafS9pwIPFnZ-9>xqjcidRJ zAvl|xYImRUSJ5fe^+Tb)33(aCytAFmzv8 zzj*EKEb0mvbQ&x zDS@-AwzYMl^(+4F3x)WZ)iQX206zA9!MO|p60e+9Xy_pRTOSCZHGT{U+xT?X?kOq+ z=Ue_6wfJB%hnMyNS*(u1Jpq}lo&^YjX{?Ss5`oF#jun)Cfyv~~HNAley56m`_yb%) zDJRZ?UxMP~gZg{%duM|Z+k$9DWAR~0owziw=`Z+3AfLG}_z0~oOZNDyF)nKm{Erad z4G8|ZfiKlg?MF4gkb`Xq2of5`3O$wx4P*_y-u{?xAA08pxxEezNb{wF%Y&eQ33z&g z_arXNH{2c67v{X?jvfc0RPn%)grwPr`>uJA5*R|tXT!bSJ;@{?0^AXv?!nYMcNn+@ zZ#HFN@9^=WAW@Qrwu+K?3Jy^gCw}p{Q6}M0R9bVmagb<>G5?P{xcCmyrt$$seYkns z#tpO425SLmbTJwy;iR^4xYe;ShG~J${j?G2}x#SmqA) zV|~~}s!7uBF|%>l6m?0G@)67a*m3yD!XmNcKW?zr9g+oflXv^DcIJ|K+mestuy}b= zMAssd3{!Z>zmeY}V!xS75ls6An8#Y^4S{#W>UKzd35-TE!qS*aM;!Gv$5K;Ednsid2`UiQ*DtjI^uG0 z-E*Ah(fcTGa;RCBQ`)6@NO(O5lR_UB@xC57#xFY)A|0+UIj=h36$^N`0f z5(&`X-{t|<^3r$F($w;Q3g_q6qg{~IuA^YN$Uc z2}-|PpinxMPqIKY^U7ySbEudbQ0fRO=EK|{#Ve+c3%K{dI1Uwj3@A&t6|<|kqJb#v zwZ-IOged3nmE$a)NCTCFrFnA5C?~vC9Xh4Td|;)7s#a*#OATaj-c_ZxI=G0py0aBX z5C)vXtNyGGCbBSo#Zq9(2fljtwVJo4U9N_lrNSN=j1RSHRG6=6Kdy<)sqj!oCRjB_ 
zTJ_zv17HA>d>4@X)m`&x%@B@l*6`};S?Z{2Z86ZT!gA{poXo`x0QtT4dJOf&$JLnw z_L&I)rh@vi*7_nRb96u*Z+Re6rFsJ*Aft{CxxvXCaTHK{*HEwCfF4%geTUSB+|bp~ zn2QdmHclNiMjjuqFH>)-Tm=HcfFq~{P5G_B)qe?0!7zjAUjkF4WsFopR`k^dzyh51Xte+o?h@RuU}5}5w>mm>cmF#YW>MfxQ${p~MBh6zmn zN z;ZavtH$FZd78dsE)vK>xzjk$Xv9hwRuC6L5C?qE*fBg87oSeL{un>z?$h-s|cFSX1 z2S8?KfnD^#s?P@5&tTzaOirf%Y(M`Gu2}!GZ-tNq1$cORdHeYK`3D3B1&4%&g-1kw zgBfOH5KC7V z&Q<)s>PBNRFl!CgmaX@MQA#Jjtt;Oc_#glbZM5!JzQ7HH&rc*QFuTuAzh%zcp)QR&+|NIy z)e*gM=9mu-9az<#_cJz5|F(-y{6&`N@n9)$Ld$Y+OKXUxiXd(}wvh+td!Y zZ7Rhs`!sB?ez!)};U4!T7fny+r@h5X^M4e6=8qQQ=L}iqN2`z3XO^tr21p9LE7#|L zjbykHb^lrwUQIV-? zps0gGqnhOWp!h{F`p091e3A|A;H4+AJ3L0rX1D@aPI;Op& zFC-rekviY~(4d?K5;_B3Fo>H)_bVl`4WSMuL7R z_n?Z+sorXKGHq8vGFOrGzZ3eU;IX{g|P*@WpVVUuV{G;&P?O$Kzvb1e4yBSUAcE z;5aW@JK%{7GMsxL@9`eP$jDSbycY-N`-;(irGAbg$4$PM)zf)R`vac{BY+({RdgVD zoN{i7X9hwDBIG9|<28*I_dmiV={=&prb~P2yawwiFXw=s$?#>;55Ot-YBm=LS zf|QL+GIB(~umNBuengS+_;TxEcgqVTv>pMSK)qahR&ByR@1;@X;x;#OeMRPebeKNR zWZa_VF-5BFc6enFpG>x_@c!=eOM(7_1LwXxrHxL6GP zdlg#ob!GM93q~8F-?WMWM$kPyiOIK>$)A2EbK90*9ZB}277GirBwi$QoW>F*rlwJ) zdrd&;tU|sIB=d76d=2o8c?D`>a7A7uJG82S(i+M70Jh143qff~hDQMS!Z09`1c-`$ z%cWzU4EH&#x?@D)P`5Xp#0COf+#m`?FOc6SzoP%V78KMUf$GD1%=S{y>JQYJs^f3B!wCl8Jgd1NV^S*Uk2!k+>|znc z1O&W=9Wg*to*ajG8ajsnp>>ihAtrG!>FPi`BB?a*sa&ITNbBoMIw_Jf{!NWf*}ZhH zxgKGqoo}_eMAEixJGvz?UacEEV_beRuA>HMuQIyptHp(I^!$^@lZP>F$4GY$%Tj~UrknLS;`NeF>+Z>J}=uS*Rrhy32mM_|yX35unP>xk|}EDO%l68KnV zB2DE0sF0v1it<@c#X4~tVg}Puk)n})i~z5@IAREJvQ(!sifjp5;nRlm=U?yfBHJgS znN|gzKFQVCst0ioON11#Za-J5qu*RXZ?uWSF}3~2J#pWKlLvfwMgn_WcpDL-rHI-h$l%xCU-bA2 zTV^hG(h&5Zs0fJPNY^uXmY2lI;mPydi=sq9gOoTZ32xhXK!ri|{B2+|dxa_aXGf6; zx~qb(C?)n!9cPBC^K1Zqcy7nm?VI^kl|B|nj!6Tmk@a7KAQRu8E;-FzJpE9_Ydt19 z)1$Q_m%yO5Cw0MtN30qGJfpvZPx3L(>w8xp={dwx#!GZ?3M)zdxjp?a zLi`%OPOm3Vwi6%X+2eo(Pv&ETE32ZOQWnlY2;T}E1-~E`0WAR8=zB0fo45r2(56JQ z^dzCCclgxU=7$fd@OCkazPMPgx`os)dLBC<$LK?H>}+;Xicnt+HsT{V4AedHorY>$ zO<@V)_2o@y#gyKCr?|6kc6fFgjBp?;Oh{3u9l{iYC9}@eF*C*zeUu2y^TYCa5zPw% 
zfMX}f*|pV10_3O^A`2C%>+SmtLUyBNV3IGlTM3DQ^K6A7{X9MVFHlZ)voi1!oOJk@ zdi;69LE|>phNGt+4yPv?nejXQ>Mrt2_M#5YEi|m#d2U;Fyp#Oh0NoebjGqdoDrH!t zO4NEsHb!#s1|-@^zQ4LZJuFTxG?I8ST7J%IWCpZN^hel<=0N`4s3x6xw zGJNr%n6N?rh{hj%S+vTGc;FJ*u}sZASMvEm8Q|hPs9Y^HB>B)+{ud8qg z6gHvBg7&$;sr_J#8{=O5S<4XbvpDmd0GL; ze9JODAod`EUR{8f_cI+z2!fVGET22b*QeD70uoSuqoU1CsnTgB#_jd(3(+-#W@@C+ za)m;?xWMZRzwI@^z1YWw9Cz_2{b@1PmGjP)-; zwmy8HY#fyQc1~#2wtO}}n>uxfy2A@FyyEYNd_{S05kUPee^$UMmR;kC*J|h6nMlTK znv|iZr|NSpSV%wTa*7QW|GOpFEi{qC4N*Yn`S*s)hzHEO~ zZ)fRAm(HCuOUj`+&3)njoxwLc>=YRS-YUS&-BZ0!qq84iQ(y3+9N~|7%)BrH9ZgOJ(bbT)j+ji&#NW(1X+2Y z(+eZg?)ZV8*S!EsGGWqvp?MvNrOFYJYkFWG#TgA$Qm}+%&`7=%v9e=d%O)g7k7n?W z`VtrY*h+l(Lk5M$32qf|mBC5ikEo61N~xgGWXWOg&*GRDZ41SZC299Tu?iu6<>q1f zo5WH4q;a04N!6rjhosrKr1`p}#kr*Ao1|6zX(YF4WCZCHyy;YG>C}$twDIZm_34cB>CCt3ECd;BycsVk z)iO98Gq~e3cfAKB1@-j>urr_L z+k#et!uH*Q+Wx%uI=GJb!v6Zg!TG}B+rqMZa4%BPq*~FmW6^AUQHD`r`%WH1D;!oU z+$wMJnp*J>$KuWS;_dq4-TC7E+u}oll4IVIQ?(MsbH|d)_>$}TlH2)```eNSf>J<3 z@p62@3PCYELn%^1DY&5&b)gjft`w8744bbEm#_HgKv5}A5qv~hZ)+Lc_c8*)_+GU_ zAO(;?r@Tk5oETQop}dPHBP3%?0J9blv)`4c(v{D-m(v@Ac>{_011qTZ^5Cs17^ExV z^eROeVx^EPnWZae7s>@1DxsVp8q!1}$W;o)M7j%A?s-++Csl3G z>aMiv$It5iW`snaoGSIziLCFcwE`8CH!mk-qA((gRPfXZ za3_o?BoJ(;4)%3IA+-pDDPq`C1qn+*NMS@oqaY+QI7$F1kuZQM3~34jJyN}PpSHH4 zRTB;%U8; zZ1zyt7lE)mFrmiF^FUBKCVT)Ch!YE9ek<^p&~#ne@kmc(|I``T9{vscQ(m0!wQz4h zOr^lAJNh+3!3hj_iquXRlg=g#nnfm`R5oWKgYRbmkdnd6E0Ap6>yYf$3W+|0>LJ921L?ji8k)ej((}YxP)$I9#lsDF6)bdqR~c*txlFJ zaZN@LD|+Y72{2Ec?L{m6;5|`Nf%1rZLxI;&ls><(3cNrr@E{tf`<4tYk`J~HW|vMofy5rMMZGp4mh;Q(ekMF-ptz?Uq(%u*rK9EWpu z?mZs^)<}$_yeB&fm%LRchK!mWrNJjWgI`I97a0WHB2T{10QC<5Bv{yV`TGFSe$-J( z-ThvJ&##B|ntBTQy5+j&j+^5GrvS&}zED7)e9Yk`Xr&!6vIvyX1QCuhvxETC$TCOb}<6q!)3G{($c&gMbsZl1u}($NPQc03+*0~2z6B2#f}>iNn6 zi(}6VG(a*$b@0`18Y|+~S(#A%;2zV!00}_-Zp|BZoO%nouT?kdAy)YcCk9P9_Ac1fk<((&Y#a<#1V?;o`bzW)L{K{B{poegU89HI z>yfU7_OOML)~Q(?+Q5mg@|ADcEGDB*=_A9&L_fbCL#7N|@10r;?_=#7H122BKz-u6 zl#TU0Gga{H8xdyGr;s9bwDXm3&LBi0!ET%t#m`NPX+tlYh2*nWX!ymJGFs%>I~uX3 
zN}_0{MB>ixD{&S(mUL*-(^tJcR~he5nJOFPNTNYIaPT5p+AiOhsGm|kg(Lgn0^%qQ z{;sn`@$SPjW@MM~NFGSc_Vif23i%7cHMV9;{^@S_CUM3t*mo7d{7f&2t3w?Q{%(Qs z;(0?4oM4;i;Qqjm*3!-C(nZ52&@Lrod&jav_D-1GBoMkS5Ytf|N%=J0mm7jg^_aX# z8o8S`+9_Dn!cl6NKK|0nOo_ZbHZ8fIS!;{bjta%pcs6qVdBr;NXC+^`JqBxN<|0?V zB3rRuKX2iWuM2(iKcH1i{Kdh}DSgC9`=CB8C3>3;SnpzC_TvxeWGcb<|3d zjSfPWo!M67uuF4tliKDMm*l?3Z~#01(&U86o1uU_-xvL5Bj5NLsx=MA z&ijvkb5E6sgHX4WAGVwzQcY~z7aR;3xMx`Uw0uqC4yQ~7GsxS?$SB%i$YUe7WF#n1 zfv#%_1Z)BYW(>V1-I6j7yd5G2eBVAeK4(~YgEoRhx_I7bJ+HDeJkzv^IK8DCgGa{au2&__thQ9*JUsVKh zr`(#e*_pr$O|}MEOVJs%5QvMIKc#tV1^vSC1y%6~VjwBMjplWR>rLK038lr=T;vsd z@$%=A(LM5g$`{unM@JV*=77*ef@bBB4QRVD@)tMcvKL@4L5AoSQK7FVA*+lsEmHnP z_rnzZJn4e8D*FXb$PP_eT3j~AA8tLXFY`;61BA)_hxqJFHp|59h#!eAp#U;&Zb}w* z+Am}}P+AZr89FoFog2O zT9TDQ{fy|fcxKx#J9RX2_|RoV@YfG1{3l+gILo${jnL;8vR%^x`T=M90%}Ir%#6wi8Yi< zlZEOHt|)xfkNb0#u>Z|qzM67_a=s3n?hznGk1;S)(KpxA+{PF2it!z&{{3beH~CUk z;tg2$UVx#2qP7=zYufmHC z2=~5z4)a-D$(9F77x;!D+n&9aTO!dZq8Fcfb$v-(L3(}^ArgB2)Mndi;4lzZkom*M z)K@a5QB{i!=$(44ACp`0mVyACt5ICyCpHge%jN%pBchC{p_+p7sciD>F9 zbwv_Rc#pbkzW``iQ8ak~HGM@2&8ByDrsPCxs;53Ock8PNQ6-1x^yyzqp9}DE`#w)- zl*rw%@-|4KLy7Za?*EbzB8X&|8S#!D!}Xh10G-(9&CiC}$*%cpBB>!#MtRxE0Y>?8 zX8A@1S@cGv1;xX=bcL160mjAkhyBJSNta0bapm)u#--f>-1}wy%mXGB!-55QxufsY z6N(b0q)n?!RSQgOmR*s}YIATI%<6ETOB2>@787FB?>8g+m^YpdJB2r#FE@Z|YUTpX zTiUnx%v<5m-m)~@;}RXWmE;Aov?FnWEjuw~`FA_8TaYbGakU%Gx{2JtR=pI~y~n-B zW{Cv-)W!v;{md<3>p{dSM)nGl318D8;u5)-VZlf6*U^qq#+N0p-4w0HMoEgkPP`K$ zela0e<5*$xUN*^SO2yoM$Yy4V`k-ozM|;6~wqMKCcHXEYm~P%gA)mvRrGhVEJs);bxj&!QPANT~UGQzxm@3D2y}18kEASWwW1j9qh9UTz^5km9YmYxzMMbOP@N>) zi6T%uUNfe*(bC4i)608HzgH&31FuTS9PhGl-A;diZz8o`neY%0Pl2K@xYluUO7NPsyj;kvoE( zVnTt!j8F6t&p>am_A$Uup2%Z~B8f5u-b+(F%L_?p6k}mTiqLt0my(!-mreqASv~PY zwg5e23sv-`V+-pS1AXT{1588m$i;cOSz`@|s96WT4C;TjfP|6x0$`ZoXQ{KeiJ;HX z?|nMwKOJ)5@{1r!t;$%qvM2r^%Hiok?0HE)5!9!#^>krp*z}ox=y@Rxi7%wr`OU6K z5cX?ZbVPs!OU^g6`VCKILvkApOn`^(J7gLlh~5-kefP29us$8W-b1RyQ{|G#sb;ozGljo z4aK8LOzCr~(5?``T;YX$d)e$3Q7{iux 
zYhs*XwcC&I{8XgUS&oWhLxry_pIuQ+~_+%VRGr}JJ|x%anAP2T!A7ZjFX`0!}F@Tup$%v zvjLAEwqHa9^3C0h;pX91@mrag-!mAaBSP%z+Qf=!{00`kBjEvra`nFk7EBKpO%^RS zp;%b}m+{s^haRWptMtd3CJ!BR=XK+i+x`ztr(!*>>ut*y z(=*LyO(oDXKCR?Kqna@ZBT9Y-(C8aTdqEV5vFU&;LJM{W&r~Td&pjC$JbY&^F_!vS zhmMUFUM+sbYJeMqo_y_fmw2$Hp_BCbI=v0L?O>Vw-mLiIeObfn;XscMwjbL~KLiaY zb$`{D{(i>KV<8`H6{#z|LH|mxUTv7pEXE`Tf&LI8I95s9<|k$a$8s#9uk|KB(dDfP9Q&Vf6w4&7zjNXp9^kaKVURQ6o54sae>|)2 zt(tz=&#-5V!_#w2ky^5G67w3No^5TCoAv17 zux#zC9RN*j*In~zOO3{_?SJ~{;G)y*<36drbAfGhPzAi^csB`e_}JsJtv=z;ZCgp+ zH)JDwZD$R;PBHa~GTNrRdYVHt!nI*_#n-b6ohNu1F8RXRLFMr92Y7Z|o9vIqtH}Ci zGeN0KRdu!%k%INkct1;$+%`AZxi)S1MW$Lc)qwt zQ%H_W;3~Tg77}R{cR4?;sb9?rxpgTZS(qsiZm)YpgJJ<)G1#G5=}zkg#*GH9WZjlR zG0rvBP{^ev*#|ozF{{xUr|}w%4PSSB6}Kz{(*iq-Szm)AJSdF{_~|SQy$Y*67e>Sf zhbBBbR){<1wfPR0-hqK3wTdyn9}j7-^I*1ncbEIMom-B930sVQ(FdT2En%{+2Mw%7 zTcKrZ#ZW5_Mo=+osc{|q;90HY%I&9stzt5A1_gn z_mHRcPONcJ=~Wl>GvlK5X}q+43GoXSc0M$4x9jnFUF$;}s}TZmw4?<~#p?Uc@Rs`e z#z9b^w!{n_@F^0i2D!Y5Aiw5KZ&W{5BcC4EuOECqS^2WFdyP-|ML@ugcwXYMjvQ1T z>>qqNa{OCzITfzyoFOi*_QqZ+ettR*V3mO8nt+cg=Kg-pVSb+iX#;C>-Lxcv{e=Br zbbGz)HHrlVt#7|<-=^zS3ADm<7}^fXi48LGwT#d4^!glhgX^|%ZClo>XSQu;;pb&F z>$dnN9vg5n`{sxnxeyXutP)z25z4|D_?R19x$S;U z3tGC?i}v$-;vdnL8V0P0h)MS*0(xTXn4y8}EX5<)dV|pH{fhkbzy9@P2Y3{gt(nfA^Icw+1=;v|o#+p+Y z1{(E7|GW+tM+EC-eS0nu6XhRsWfd2m9`j5$NLL~p9TZDA?X6`PuEP`KXXh@a8cc!c zZc|tnYZ~iiuIk@ZYvZFDha3}k_8q|6N})#?K+zN9F8P(~#=Cq9nk^AOYHwG4$aP;L&Dp*gvF{jbx2}5PrM&)yx~n8pF_AnUt(!)G!b1= zeqH=2k7=P~#LPwT<4n>c-A$4iPIB;ZgvxXb`cBA3oW_o1Y{qQz#$^H_FlpvG_6vT* zd0glb{ZG3BxijIs>_vW*Eu;dv9=<037LJbB|%Tm8dN3>3_+ zz;G`!MOHIhZgYfhbGZ90o#}IZ`ZL_)bDgDft!i?(9dpC#b0c`ZY|;o|%A z5_a>Fkn)r3^WI5?KFafku|V??!*g&eQv*-4T-Ou9W%;U-c>MYWH30>6js=zZ1x>pJ zUoOpC=nET=GWzKF~z0?t)K+9-A-iDm(+9h=c8kV0U1lnm|NjUkjqIK%E_h66U9ne z-jp^*mjmkaPazqFtQlkD`88uD7zri!0Ts7u6$E@>4yO{{gbL1v3T1APAaW&FLkZtN zMLSt}>PF=|S>&qZ4LHDR*==5F{A`h*Di}L~9ygFycNc>Z07&4d{)nt=U|g+N06xvH zG~KJFVE~)lRa>i9bdc424XoMcH9720F2$;HBdoG>tQuD#EZ&=qn|GwwVqk;}a42op 
zh9cL66V}np)wrrxLkVk%Fe^L>B~+11x5nWqw0l?0S8fT|+*yNHSkCK^n+w9Mr}@Q0By{a~GIB5SL%j*oe$Wsg^#u8dgUL z_9koelLNJ+VlS{_l1ZsfCfGMDgk?Hef8T4=@@>@r*3_ZXgd7e4S~Lx&*4D!}Z%a2j zT#D3h!D5z>EN*l2O80;F8aS5#Cbt7WgCX8(yf_ zU){C-h^$J=6YvHAQMup?$l83?0B)RUdf#!+ECAxpfM>#u4S=rqZ@XmKkl*SdBeTG} zvSPWMKy+sCpYo${@>B3K3L(FZ<|jhA$9ke-l%>JjV-uA1?j@dOW0q}WkJWw;8ZFTA zZIA0>j|;fhW3b0lqnG5hxQHT9LB1L;J7Gi!h3SOZtM$IF1U6`+(eq#!rrx5snQM!ODA!i%o%pz9j)C3;|%YoCXfS zBF2`9*_ImFheW*xj3Y;kBhbWbl^fw#f*0D zzXLbDn+zD73g{ac9NbCtzFF|?Bx`*mg4WhbvwRE^q-Z5?!eKuddjEc4Onx{|1LaE) zBP8kd$nI$W%ZVPviMCGH>CB0#&WTB?dI8$t3|in?r;^xDnN=!{r5{SmUzLUk@x4^6%8O|!Kh#SRu|9(s$u$2pAIj_B74k9sVdCX7T;G?WFL zVX+*4#{;T{V$Z#Wk+*zN89+$Sax5j4taCylkSurz#IogQT!ZaJQQe$Zc%^nciT>!-E?e&0HH zaA_${p)D>2DmbN-;#O#JE2X$Yakm7L;10pvT>}IU?(Xj1;_~vGGw(g;&fNcC|Frka zXRq&C?f-IsLgL-%D@@0kh!=qkO@~^_URuiYS{Yc{NjYULgoLE~T5X5ryK7iC6nGf0arrVuq4AfWp|P)xrG{ z23<6KqIldo*t{4giCyP>0>><`CWVslPh*b;iy|mg{o`}ypIF|%P^^(7i>1o!^`SEk zYp#ntALA+LW2nStEkDL9Q0rw|Bqeh#^`D1aH5c>1!!Fc;b=LTQm>2=1c+*g#Y{X!# z$ecZ_y_|5~s(E@6VCGkA$R76TV0>9w*t#_sXk@uEfV3Y{NU@gA3nBD*rY~%roBfY! zmf9u`M&){I>be3iextwkFxyDIh!1SXZxzR-M1M`93ba}3^C#IZCK&o-4X!L7#akU? 
zjG8wkYWzL83ff=s77QkBNv+6SZh9PDwRMi0O^6#e&(I^y@xkTTxPqZyXXDZD$`8*V_K5kW*S+mO+hvRfv#bW1qLvqF(L3VeeAMv3E1 zkD}?1w(W{jZ>}Hz%LL59e$~<&+N0yobt9fn+HfCb}KZl-ZjVvLsV=pFhsF>7QZ=MZ3$J zdlu~&rv?ou9ZFN5^DW0kN)Me(@+TV7ovzg$P8u$8TGMc4oB{wOX`aXQ%l_PN33;h@ zHt((Ol~U2twlUHqu?!L>U?(LFo3JHaeKDU{B~ZK=4W#+}v1NT0eU*FLn#L&QwP6nc z&ifMG?|7nWXN==)Z8xYr?2_qi-dmx~QR_TW!I6LH=O5Ct2agW$)Ba6tE^;uhX&78o zPMsq$=Pn}A&`xCZR7o6{8C!{VIF1JHN-u}P(!YNqGbV<}eL4(QJ`XcLhYW9y&?n#J znJ}GP@mT(gTyB!_Y4jZCFf{$!h<-N8ch!*^mt_D$5r?Y|zsU{%R2Y8o=kaX#Vdfd$ zt*$K`JPiBu9bNDk(%5PmeQZ1y{vnn4U;5*5X4=EEhDW&8^LnmJsCj;nCVI?BMwC#%X8C zzZ>RD$R~30i%ASkzaV~vIU;slRCz`Rh~%|giVW{}3Hf_=_2+itzTM#fZtv%^Z0ka} zdpwDfk`Ai<5hcaFfx(rd1Ifc;nF_dPZj=3A@Y^_skvD)|*;%uR1+keE4i5*wQ(}Z%pzb~k!k>EqJ zFs;Pw#*D1>(0^z-5OjOXQM5(GEIwsgH`c z9|Go6Gk2E;oA_-vO%SrZ8 z=hCk+-zfM|h20-pe|ved-k;WLF2KKhb&j8V+$;8|H2+~Eo_7RhGQ;#1i;~izS94n7 z{YOF8aTzS<4=a)@2|da|i(h}Fy1k6jU<5pVxAS+~_sUuE;jG2$oR&@sb(eC@nZg81X+IDRg~Y1H{s7{?Z=;)BXW4xWDFE+C7(q(As_ z(rLZChx_*6mz%!nMpttxcs~9_sbTA6YVmsPIb+yK&WJmTb5fFpIJeq|qaFP-RuN9s zfF70d?Ztx2GA=ug>+;Kt)zGlJ5x)W*+@2kl+J+OjsXOP6zWxg@qk+7%PY;(YnIBAa zzdWzK@03QOj}19vf5$}M4iwg^CbXdnFl6>+Dqv?KGXdHJ-8d6q9a*`##-?W4MWMQ9qtVl=o-8Ajf5S8--437SgZ7%+@Q{ z_*En>2ExoH7;na~%_$ftuPw#ary54aMf)s}U8A>8z4FZeyf0mo_L@ShJJO&Qu}SZ^ zo2-1s_X)|i#sHF4C^obWot4Pui@*LLr9GPXmG9-acSxF^sw4jtf;n&}Us01x-uv~q z9v2_a8^&ETbO7a;;JCRGW2~bL)RjeORd(2IE*gFnWBU59*0ASNRoaa)rr>ge(x0iS zbO|FC5vs^M+D())2`pcdMM&PB?>zG<3ike4J_h!Vxb=&FQv5#Aiq>M5*=wKK?hcsy zB}`&o%_Qe5nD*s$FrDWJ4DJ?vEF7N?a-u20xF-z@-c!k>{vgdSXMs&kI`LBQgN(}W z&#DxSU_ZH<3{AR1^_P8_CpAMvE? 
zWu1U5&6U?RpB(PG*6}Z2Blug&PmgjmAqJGYb*+G48!uD-m8%7Qr1PKG{)0`6Si%Bj zSYdjB%Z@8JO}n||n`J?p=gUuq{mN%Hx;Rn`v&7LwZCeudtlAarjnxTVC3Y2vYphD9MY8C=m3Q)SWLnxi249w!$W>hFhO{?O8Ni z1o^c2OsCVSBkaN|8}b50Mt&Y1_mHL7s<5TRGJ+;#2f(oeZ;{6} zlIPjHD4MpPUp8Xa7n^}E?jJN&)G3l*tStnz*R{hMh7_C@Jp;pcjK<1t&w+7es% zx(xrT@%S+3Ld3zY#cP47Rcc4O0TJBa{2o)wrGM%ZMPaKoF>~Z0Qn*Sg{(tI{7mvdD z5!?4i*6b3Cf8#EjE=!L8%-=4R?6|bx3>`na^Ke0-vTL*!>56vC%b_`%`=KyXenv z?FOT7hw)9DCyH1vW}dY9>+bt5&os6LHoLfLE`en@9?p3l#IDXJmmbg+?{-r!6%5}Z z2A>OQpPM?*`vAv()BJ8dK2v>p!qwX|kjIHS4)B>jUbi4e83Y zV#NbWYkZs019KqGmd@?~DB0mXl?iZ4FSM2bWM7#f_c@^O=M2;T8v5Qqb*3<2on0G6l4 zHCIah8AcxxBWMM{)eAsFMxm2L14sZ*yzI7X6rFuQE=fcgZv?O*0GvGRjz0*@)e3trt5H>THgrAs}tH zP{Y~4gzG>iO|(fz{|D0`Q>`ST=r~vdE>uQ@DxJVi7RZSku7e5SV0&V2f!kL}0Av&| zZcL^L@Err7QVTD57W(e-UC88$$LkQgP`6-9R|hpJ&)q3cTHMHFbB4kls#&&3#nZ&C zZ|HNHvAKe&^m!23v6L9sz{r)b&Z>AX^QUKp09KN=U6cQ~0nj4SZrs8?ti)ft0{eetdOHjDnlQaPC`;r zdD2O5JQxW+E+F8%P7N$inoxn*v!%1hWp$dT7H7tU?fT*iL7@#PuVga8Jc(Uj;PvF8 zHGiRSD%kU0H$5#Fm_Z?5BaTr!X9<)Q;*vhm7RR!e`8G6Dr#c=-JmJVJCKM3!aWBnW z2oIAzqckf=$Ll6bX;)05FVSv~;Lz0OYl7hvZjN&!TwW+LX*UmK=Z=k`J@+S8y)oa| zFG)NUmsuOg9tu2Viw2a(i)sN&uEDQs!_3@)X%=`SNDPt_3mg9wdkg3vJZG2Cyw0oq zgPD-8h4}`OMWFm_^1kcOm4Z`6gX7zvV8 zOe+mUlEOz_czTt#1kUydC>$Q~T1KS{3qd~%-NSs->3U>{t7i4Qn(1*8_1%r@p? zXJ;ul6vBEcvVpNQmZii2rBxOwU7#Z0q#}Pg7=k)3Z!NT=u#As`ceOA$%{Q7mF{-*Z z3Ljv>DVXb|RgP{5*f%T3SplGtq|~}qq+3)3h2;4ACuFUcl(OfI%w~O@PhTq(f79>m zDMw0^M5TjUcFtZl7ZR$vmys=4W!2Af{eh-DB@PW6KexIey00AYPoNUyTSZ({EzOZ- zF;`7VgJ-R6ReMwR{g$B&uP6vmZdX+Jil&HAs}c*Z_OLJGc&^f?Fe9F$w)mH*9y8cl zxbe)MMyQ_P50|PEwbxidlMQY9TDe!VXOVbDUAr@q{ViW;ydG$dU-eS1 zv3sBXwRV$6V`J4=V|5Y2_O@QTKWjX>X4AQeY}F;owSIP_p*guhX|5;<-`_;1>;{lZ z6QZ!-fjAHjP`L$;_IoXaB97%*RFmI^$UjrxPhYoaE-y+aW{1Z$wd{Xr1!Yto;UWI? 
zRIYOP7oyOEK^x+S@l)|V`SP&ef(7lfczj!>pMKq(CJf(xgz70^^n0ur7TH)EIP?K{~$o*|wA0@~msmV*I<+pt8 z(=5&{LNp=K;R_@I-*^ily0jw^k@y@qBcz|8te>rzvd}9lY1$tx+jSw$q<$T0@<^@F z$`lAP0l%Y-#!H>^xu3v0_JJ;6L*Ij%iza9t zHN>Rf`20Gi_j7qKTXk>3ovUkLZyKjW%0X|ERbPf~NUm;Qc1mB_LT{nMuVUK%&rbc~ zlKr-rZJ4W)lPv^akq43=Q}X?E2N>Eqte`CNMxuLrOq@l}*$Y5Du@Py~!D1rHm z$e5NTvu4j#j9`QR{TKyhZ;FVE~SXI1te zsyhe6tlhV(eWRUL|yO$E*q0-sL*jp}THqmhoQT313K%<=TgFqGF z4<%7OAmq=4RBYFp>}ZPuq4$f!4gs?f!qM->KPG=+sjI^$je;h%7bm~x0KXqj{(yHb zAv&pQ?S1?woPLifE>6`UNh$SkjTR@rHme2Dvh)|LS=CH?!qt~Ovtx_3%}b`_4tx9# zr;c>%&D^G8*`v?k(}MvsHSDu7gb~ntg?Ku}u)`rcXe+3s1FkoldFap)JIm{U%DN{^ zwk9pI9*BU?brsH#y`Co{oPYMWN84(i(r>%+DAk8E(!WZC;37Tv!CpO?iR(VBGV~Vx9MkYjnc*h2;cOi_dizWxz{^MN8}V zOB#WTD+J4hgG(!1lWU@je+eOG&xT1%=6Euff9x%>nJgQc8e2>KdTa6vPsvPteMPWy zMT~h>A(sQFv?>|ADiZ8KxU|Z&vC25HLUW|0Q~O&h_&2NfZ;cnfGc0Fr@0ZS+d3Yb* zT{V;Bf0Xl#Ter$x7bxWrepnYE+VHPk7o*>h65EjWUiVku@K-ncda?0Ai57im(PG+u z0N3UfbQ7>`6&PTS`O`cRd&{8IyykulCAv-lALh5()Zn(!I@;19+74l`a$=;naz3(j zAzH~5UpcT?f%R_<5bXF3ZQ6#Z5)+P&b`P06wk((tP3k>gYJR;_T)1g%v$H;>XspN7 zK-l?%8x1qh4)4-+^kg@?#cq**k3eJZ#(58y!nTyiwrt3jX~c`k$iCKlpVXxc1F^Hb zxU;ynkGbelB0lphNU4|`t&{s(cbG#*>H$g30mmiEaj9muf4=R3#2yqnYg0Q<)U|lj zSybP1!Csu%zYi_l#>6~`quH!q!yD39OO50X+C0F1b$rbIuAjWMo;7izrJ0$=$C??Y7S6pN{`uP$@966?^4A>LNXhv8 zJsJ}WUFe8V7$Z9V2UhM;O*wZ?`*(ghqNcqS@c4K={^N?%XdzTH_|;4PJf92J;nv8v zNH*H55S-gTt&A%=|JFGEZQys|era)o%Ov(`ewE?Z z)n(Uou9wNmk8-H%{INCP<@^5-w=$Y$uHc0HdeqF$6_AQeP*%&_ErgA=JaKmS8kZWc zfsMh!eXijZqhDxYUs6G1XAhkTRV4L&U=0xy}>6Y9BO0ck>T?t5NaNa$I6 z!MvL1JO$wQgf^9n9Hcg-gULF*c#%$;fNV)J}t;5U5kHkFbPslAk^&6SJ z^!A`QGCt4zFq%%AV{RB^X4mMISHk;O?y;O3Z^J>7NzXspk8p*s@#g-OBBt$?;Q3{e zEb9m#v=zXeTre-$X7CWNfR|76jrL_LBnx@reJGqLz@k%arXFNh_oGeQYoCZU;x#Al z;(P7Fm#v~`KF>@j{EwqE?dxL2--0Teni{~)i_Lus*5#L{f3?4`lOIiWL1~1&2?QP9 zltgg|C#@O=|Mm=Bjw`{udY3ce!g9(q^@3`7!aSOLaqil11#W3NPoOu%^I@UU1;Q*@ z?SgozW?(B8CerYDwmHE2&4x(yrnEw7Q#+Ik<<4~Ar2+rJ;`fi|IJ7hQuP!yai+}B} zsUgb*7u)N`>$&7S8vBWVW4E`D#D>0)A}oe}XlFen{?LfHQQH94dZ5h~TGd=&Z+)$V 
zyS|{AipO6>yp9*GGlh}oYxkKjub^FZe{U5kQ$r}BdiwJZK9du|TF19}=13McUaEz@m&jbJb6M{m;)8&G~)C+C`J%-hjk?p9q=?f8sWuiQ2T_1QzSMVWvU#P}iC9_CNL zt>-Q;wk`*>7Ql+(OVA#OkF0MS(GK8ALL;%@$w8+XK=Vg1KnZ%VUxU%xuYkzXwyQaR z9()20Xa$Y5tS`?HI>4G|tXsUQk%Z$t(<@i)@5br)Q~lfE>yL0U==ZOK%57V~JEnLL zSrEaA0j%_DM^x2K(SE^UcyNoCC67R%;0wk&AD|ffoE<<3{8CiXwo?3#=fmCJyOA~& zfp2juCkK>h81Gp}b-OMBptjA7Q(<`-8{-TZh+%k?rldpo4-tcsr)xIrc@UaXylD8- zcv~q(ilEEoq68U%9PN;h)Mjv2@a+1DxAHil*WbsnxAC8`xs8$j__O@Pz86bXD$NWt zK4td|8>Y%dSM!!>4`thRtq&Kw;V)aSkMSM6?hZA|TCYzH%398^k3CjzCm-vjKV2Lu z`<&nY_&3MTR!%7CAC!>9#0jJp<=u7G?x*v6fnEnX9eUXj9dhC;oxDysV8x+;i68Tn zWs|fB952csv;}w2zCgJ&JHneta^vr##+gI% zl}beQf|TA$hdz6WnN3s9F=?WBi_lGmY3|<6IqE&;M>;fNc?(Ap`w2Z_^J`%zv0n)< z*|+$@uX|LwpON}u;+=mtAN1UFOuJJUr+Pe92F#4TuOCojato1EsLhV}P5qotEuT@u zD9bf#<8}M<({_{Ya`aBWT%lXs$FYb|r)=FjvrjF!Zx~dgXcq^1xdnmJ21I??kJQ`8 z>OV1`_Vc`_{and~IxzJi=1F)jp%m^qkcs{yUx;Z>B|v*3%jS$+?vI69#?AL!xyjr& zh3x8OeG@rSKY%hhd7|G{(5eeRlvwE%>C6kEyhmfQol1)IjvI#>_|CF0VkHe;+b0oH zpGMK&Gzv_ zY*~eoc-Rt^)uJ z`eKZF6SYuzYE6kP%yI)IG~LMjPtk2LR}aHurjz?)__380YC?OWw|2BLnefm|-UBY1 z`mxAsP}W%Imc4zcwyu@%&?eya>*MCf+MdA#(=3n0gWB5qc`su{b(*JKDSz`Z%$86YjUsb;?Gl_ngy#dWp6B6+@rK#Y(GCOeV1nWb^*WjB>APcl8wQ?<#$iiux(3k;>&+4IG8so@h#_NFFiI{ zQg0%VjVBG^_a_vzchYrY;x%tw&mlhdUoO3mdXM!l-r`=LJ(fLotbTI(x1RW5KHYl% z_~^Mq;d<4h3bT)$3$uUcq z&=^_$-vD)(1^kQb{Mi%ynNa?$z<{@`0o*bH9GU@KW&wiW0Ab1i{+R%ws{j#!0C9mp ziJgG2l!3C~KzY|dMVUY)vp{9mATCzj##mjF1o?(o|Mp2Wu^=^rouDUk=PzKeF>CN| zaIk59un8*Id?wg(C)geca=Z%u!wPbi0lDOZ+)yCb9gxQr$SWcEM}835j2{SS`<=-c zf&#;Kz~R7M_+6*=ehRx^07E!R}9oQ-`Y@PL2uMA{c)3GN%Y#;Sx?hHEyhM%&AzYV7K zneyZd4!gc`{EIS;nF@d02}i??z+j8Ol8wMIkH8Iyz%Pg(?2RDajUdI1e8v_@E*nW< z9!V7vd7mFHixxqL8^y>L^;$NH$vlcBB#NydilaA*Yd7jGZZt1jH1|8%=nv-6{2|eT z1<}I2(W1N2Pa)LLY%yPCW4@ZlNQcD87R1Q+#whN_DC5R{<7A6fla2jh9;@+GMJ#IBJl-WF-mM_sqc`4b zH{J&~!H+E=KsF)BJOLDv5W-xL0O?JD?IwidCPuO)M$0C~nkPP`P!kIhlY0~4yNPMI zNf~TOf3jqga?F$RLXrv!l8SnhN_LaVaFZ+8lB;BsYZ4PXM}poYBpc=@7mXxy)hGYF zO77@Q?qW;nrb_AEO-41O^z5ee;ie4Y!aJ{1#?0ZPA@B)7_&_3jp#VO84PVrPuQkBe 
zsZu-5lXrKYPV7?;cTKs9*I!Obtr$0MML%T`C z6iUakfZ!#i6S}7p+@upVrlYfG0I4&ug)%6TGA6S!UNmNq7pBwWWs(VHQfp_voXcRu z%b?lIc$Jj-<|dPwI*YY0lU*+3-CpMVB(%}4G|{B+G7{6z?Ac%BvcFnnONVAZrBLPj zvK9BTmGN@Eg=QaZXZ>LJnU641EzHsB%hB7*F~H0H$(}1|0amxjO^IhDGlZ;k>0%k$XF^TNybVbAxI%MY-~4+_l(73POf=i$WVn&K6J zf!XO&1+f+d@u3BYg$1t3eFgBnf;7Cs4EDk-xx$<|bE=PM{p1Bja(?1_=zzcCD^@QKGnDv)g~0x?&UVl6(QJ*kfBAbjYVB@#XZ!;-ATnLq2m6%;z7KUmYd=+>XLT3 zl4*;Q*}~$<#*&%2lDWR(CF;@@p^^>ll1e#AE=A(8_NKeOZ z9u;`~73ee-t%d*bD3CXmGc;Aany+BgDd$M8yx9{!-9z8&^F6ev%E`;UxiMeCDB~6{ zJJl8~>oYGatPcYQYcU*^OvM*-i zE_$+OE()m+)~OLmt_K%YL+9&5J&GddtLnpb%3?UG0wE2F_zf|Z4axlt@%s%Bj>gE_ zCu3)0LULo4a6`U)V|r6#?rmc_$CJ^ah~}!UfyTZ}yxwrWK1seQm&F2sY#D2>qr1iR}^|rtDe!q3I-+WHK6+^+3?4%V3io`8O;twFNIm%95 z<&mVGZWr@Na;r89Xd4xxwIZ~M?4WJ=qLiAm{k1|nlU2J!G8)}o8waQLbIP~Y#GY`H-kcO$8`{2t;=Z<9@d&G@`PM2G(O0F=AC}yEhTq>%te>yj z->lH@3F&X-91yC-Yq1)rxf7Qc8R#dFM|chl>-N^}^bIHs<}>$?KnMLr`o?t!7dg#_ zQU+IPd!C3ejU}t0*MA1riib=ViWWSFjy->^AcoEdy8i7Aoj`}1bO-(p4BLA4UlWW> zDfHeejQsW-MuUww#|~i)j)>h2;}VXlG7l3fjtY5>kitgqsz=BNN8jF!e5WED+y5|1 zr#SY;bCeM_Hi#Hz8XV)f8)YLLclj{Jr8rLLImQbc2Um}N7#t7t92X><`1yWZRB__| zcKkDJLa=)L>)-^D=Y%ZbZgut~?Ni64WLJGWz6gj10BlX{9%>mHLo zVN*(9x_=K&9o$Zu6HXhxpR!h*PVktrfBFDcPW~C3ZuOXUBb>?oIO3%^({(%T2b;mJ zoDNc)v6Y9FV0mBQsxuR>ow0+m&^;lpQ&q}t(0z_Z&Q>?Je-e)&$q)CMo#Ct1}Bj4 zg<-hjK=W)5;S+I7)z7skdblt(xY#1SFgrLsvADPnGxj4`GHzbnC|N4}vlPU&B(qPr zOp}>=s5rArw|q&5SF5=E{(kA|aM_V>`BrcG;(i%R%mZD2`rdj4Utettw(?bRg@Ah% zn!ZfDG>Rv-N_8aJY`v;WxB8WuXbtCTk#1>%l75XP)e}u=jecpCwRDxdWsSgkjaz9w zG;fV}Y2u~+IswtTz|!!$mi5mMz?sAKyW(|;hhE3&zDSL$vv$;!q@ppUW-m}GG`w)@6+SHAt z)crl(eZ<2|{loqo?Lm9kOr`z-#Pi@2ifE>@<-irOKSZ?LM}O$Qa4q@%ka9Se zIyJR)=z4d!LO-=2c66Y7w3Rx!mU?81I69!8+<7?C6FK~JX0v#te{4c~d=)nFx8+#P z^Z4Om;fDT1Z{Zj-Z5-fpq9SsF_jUn$`9v+{goJpU;M1wp-N}oxd9t)q72VTx+LvRL zkEf!D(^t!L^af{Av}dg0qiBk_C;iXIXWmYoUtJeS`5`Z%?T2du0EAp=s-KBCXuEWcJ zZI8tlvexG(a-5pPf0L2_M4NDvM$3qvSG6X;#~wjDyWx;Qq9vCDWDKuj3@$`HE^t#W zV8p{##~0CW0incY$%cS9xdsNqTU@*PxoAE#yPK-;tM=BbQiGN7$HK_R3s}=lnfrCp 
za#`{`TbeTBr0>1#ay?qad#AS;GAfND4!pP|=&u|B$j3t76#%>vzz|Wt@`-ERhkNT7 zeW&cF|(7z4<|%>!gbM+eeTadVSkgV2R|p)oPMSoje{B(ZQ@#zb^{ zG+Z7K8YC7sAt?o(i=Gev4}M!iM#cd^G`F-O+uA!iySjT&z5j*Z4v&nEJ&sRIPEF6u z{=dYxGaLOWhd3=P5?zL&SHOWx0iBPLOJ;;^3xy>@ zbovXk{EK`zHdu8|x~%XA1(n00L(l)=Z~rgtt?sI<%%2)=F6=9twy3YY2G{@q+M}ZD zB?)Gm#yevnwMj!^bZ_~3X!;C^N{mHUH= z3>FxKPg2D({rH^85F1FwLwiMY`t6J9YNS4$1T_W#noJ!zqB!zR+LSUwdGtg6Hw*wp z+TD*i`RP(OIqlU#gW;Cu|Hf~XTU+jLucrG%hhOeD%P?ZHuK12{kZ8(_Pp|hf-`ex3mw2f>wv~Ytvay^k_eH>42dN%q*UWBFOS!c zcz!~sPWf449GsYL7b)}NdmITS$Pk6{VFYZwiIV&AiWXan2=ljCfaL75o7QgBa=c+M z+g5^cocUIwS#H5r(je-`Qk?o~(0Ga`+WT!u&7m_ez;~7@ib%8VjX!|$dnHQ{grJ)C zC)FESJ79VYEp>ioCeNOkg95poOxpWY|8;5bPx5vskF}i)dI&RxkkMvsfTk$yf}$o=&8vRT>fe_yJJ;_=9bVvqbumpVB6dn9{iBL9|&&gsPgwM<;-vOtE`d#~!D*?fTfeo## z@Bx7@W>m^4s%PkMsIS>!2ls)w4ompJNEN)9vDP#n`li_ln;%CQNI(*J;yAAU7+U0j zPMOk{|1+?WwV)?7eHfu1;Pz^g$psNBT4%2RMRv{iubE66rN6BLbRQNhMjHUI5EqqefCuc zY>bA{5JPmP$%~B%`1V!J;2W7hP2o=_FHfWAuaA3b40?Sd!UER-J@39!Z~Z1keV*h* zVPgtlc4Z0`8X03aYyz*+^<6V#r@)eq!wyY%+!w|vpA&W5id^Cj9CNbTP@ z1{DSyLGa%cNC0M@6+#os)n=6UHVZx{G@Hhdg+RX2f zSkycUbFF3KmgkPA%xnqag?lwt%Ik$xZdSm!*&e9)&iun>0eIuLAk?9lX+t>82 z0*7#7i-ombDJQ8~tO;fYm9e%}G6ph?YB_K8l^@yc*J4=Ad{LX`FG=XMs?-IUqVYE% z&*!On*!tuoDTKeK@ujn&t&qG@L63`$=m_XI#NfqX%Eu>|8~Ok9Ol|Ka{GBV@r_>nV z-=lYHV=|+8kbQkJOrT>|V`6jNDB1L+1eDNe`}t4CBv}zt!Td~3`Hyz+WZx~j^j+jHx!1k}4@}QRe9;2T%n>C| z_-#3?#L5&B2EtYgV;SI@XWxhyG#yqB&Ylp+W@Ht4_Mu?1Pkkkf?_Vo{Q`Q@9r%a*HH zYmI>GA3}JN#pgG=Hm3D0M1lKk*yT#1Pp^5LYU8KV!M;v`=`Wa8R>jSu35N31)1Hix zE_8M$pY3Ec|J;Q6Mr{_bQ$|cFeg}eaM?SNfh#MVVJ!4oU!((&0if?jufR+rFMSv7G zUeQp&YJVSsEjE{NwWTsaO`(QiQ+cTw*1fE>d7pmry;^y-6-cSq8RmOMQQdpc1ppdy z73+p9hiaJuDhS_VXGMHh_}e{oi~qsp8I~ftj$Mm1QT2a%OcD^Q*1O+)l%&$A9xp2+ z@-fEGwbCY{NO93mi+GGty8-Jg z#bz_ceM!ZRnfw=dSvTt0RiC^@a==#H`s^ZQu0 zoaYMP@2Y5M^|lh8&}`cebG~hy+Qt4W;+V3c5Z*Y`qjXW3RrK zaf*UDNW55FH>n?PTDly%yi!=)(5tKKYvsCZ?+`iJV0NFm9lGikemLC~zb!0jvGJ6>b!3F{)iubX`|sW zCa-sYrlEWDOyqL<7Up@}^mu<_P6=c z?706*($4rtYk0n>cdF6wf3F$vaYpLg)S)ZSpL-_2ciOCJ%Yo4~KvFYMYKEUh)Aw(; 
zGv`i#5|9OcX=i`wE3Fx%?i!@3;kOO)>;LGXd==C@_2;vu{TFbMiEFS~0<)}uz&BKo zHBPXdCPn9^^Eidg_k`el*94F|>W4PU8*@2U}DGTgkXk zQ3eED1@80&sDZ)Igpj2Akb{VTzuSMd)Irgdpye}ACZ$aRDx|PJw78xy>{_q`>8pnhT!m8Xc=*ovc4`i83ULQembZ7P*rU=Dq{fjrXT!2|ICxIBU2*sYowe!$W3Z zV`@Gv6yf8*(A6EQZE!fIRs<6&tee>zupQnB_P*Tq#W0T`6O07phlc~h&QTFHnxSkU z=sTH62CAqJ;D|$%^Ygq&!fRNtM)<4QDDGXa7Y*TX3NUL!WD!M#DtYAFhG>yq*Efj~ z%-6vm1VhkVqI7zqL_-o|H0$}&Wm2} zBcHstxDqwJiN0pw?pc2qEvLTOlqc=&;Ce#sZ0xvTm{V}FX94_4dyDv+qI+pKb`8hX zgloINx1OxGi9zjdF&jOp(NvyBf@y!Aw6`)dAeWxkks(j(E!W9z+LQIRyB+>VFdgeA zz1cY(@5y>A$Gl^n+yTs>?t4k5{U7P=zqWLp0zaCk0e3uu;Xl$_dWxioRP#(8ZT{wl z%=g+^rl2HAenZ~ItREF=AD^7Jd-Pm;nHIfSQrdR!sIyI;q_@Y6nXmU!r7beW`m(=k z=NR#3gEX?gxo5LltH|QznkD5P`sQ@krv7fs&@0S!5X$Rh$mPDt z)e*{ZYRvP*BecSEL}yQR-HWxK%L`7*b?!?H=*xPN->Qb)hi2zd7uC<@6=@gM z%;h)S6p0HJA>@k6UDaq(VJ&s_0K_ToP6!mhpI3HIb(>JoV0e)8T zg7LnRwZ??$n-WN2$$C;LioIZKu2h4xbdS2sSgZ6%J9kHsz3j@OHKQdRI=9=p0n=JfnpkuE`e2U;YvFp= zy*gRYe0>f}eJCU6n_|Ntu$wXR0 zZjlEC$Y&{S)oiUV?$Sy3irW~hjBys)l)-I`p6zuT$TthA%+Pl36e$NrG${tm&m$Oo zPt$G0|5TR#C$jXv8%uYcQdS*O3}~c09SsanRm7n^i$HPmpg{)EpUO&M1E{E{7@cl!jAw6LN^b(9H))|a z<*qlCwl7_zFY^h#_3X<{>B~p-6)yA@-}PlG^oFFMf_3{(uzMp?K%9(*{fJ__`h|WZ z?LfQ8K&S3Nx90#VWuWg#?YJ=TgyD|R(+-Y_3{L0{PI(T_qzuj>1{W3vm+l5vXouEB zhBoMQhqgS2c2b7+5JLwGLq~T*C$z(7BEx@mhc7&bX|Ga-ZxF+G3&Rh0!%y5Ty66a| z-UznW2oOGkN7Fn)usA|=KSDw`N+vq`TyOM+*C-`?l)8D8mU?lN{(h8!ZtRul*c-hu zX0I_;_!xWh80X>`nEQT=hi;rtbo{;E_(!jC0r>bK1KQ^Ze6@jbF*v@cSBJRb#FzVV zX|9RSdJ{6u6S9L7a@G?{qLZq6lZw`p-@GQZ;ghK0I zbjs3t>NkAK>VC=^jxPio_4k~1g-^RTPkSy-+vrYH98UX-&IIbs1bfZ=93q%`;nf8| zq9P7v!U<+0q5V;dsBp#E1h3g7y4l3R*<{h#H0#+^(YbWRxlFyee6P7ez1iZ$*@DHn z48r+buKBXVxgzWNYOnbU*nIuqOsL|h^Es-SVF4k&V7`UwfGu>DEc7(@_a80{!WTxY z7bXT5CPf#g;EOZOi(`w6!~b6ZUl^e0?cLG+-QpeI`=gr>f?cVFn z-R~{m@$KC${M*L70`TCAzK!1z4G)B@iTCZ_!z>s4ZK(uq&;Wkm0iNImZQ#8);6M)E z%L^Xi2tMHg4dJ|4;TXQq9KMSZ-r&&Wa&^;&n{ot(f90KF~0(ED0b3 z$Rz+HAOa*XXe<97<1g&u%VH{b@&U~u4FvE4FHjCgp5%Em<39e+PF^e-KpY1V03^@> 
zJm6N4KmgJJN#)?>dDG)he!@VGERYlxz)(#BU;#z`3}8MEZ+AHn6$V(5f!&#Z1Hrz8r-peiE3=;5^In*sneFzLOJKrz5N zhcE&y;0wDh08IW1eZ>pKJ`$9kT^>>E%`)h%F2QA9EL(E{#y~Y|ZtF?|4F7@bxeo3I zfB_b8NEsVI0?5(rq8(=Qrc+ z;m!;2-s^|J=HM;>i_Qz_)$UF6?gh^aP~sR8KP${4?fPE8{5~c;(dNX!6AjNo>h25Q zF7bQ<3_LOKz7Xq!#0z@?@g?6b1YcUeaPrI2@f+XHIv*wnKMW4i3j=@i6aVux9}Jl= z0=@9?Dj)F(pYpx{3kCoKC8O~>-@nz)CLezcD=+j;Z}LSy@dhvUV-NK)-}EzY^vK0O zO!O>0fA#77_Foe80DlMq|Mk6q^JNe5P`~g?|LNeK_GE7l;}r=kzbsY{_wXzC`SBn3 zP7K&$QBx4i71I;7C_=-=?q<K`|Pay_OVQzPV^c9?h6nM$Wi=H5B(0W6t@2IzOY@)4-A>$0KmW<$DS;SZ~ftm z{q~_zhu;fw;$FT09e~gQFbL4XA%I9Y$bgukI5>l7;7FLUfG9|KX+b&IF@iYr$SBY` z=+OAcg9!8J`U)E>J4;(@dke54v4Y#{`wJW_JWO0{e2ko|yv*F}{0to}JxyJ0eT|*1 zz0KY2{S6+Tt;;KJUTJ_R;1B=d*s5e>OX5 zHm%yVY}>kh3pcLZxpeE=eTnuiumlihjO52rR)BsBbQopoDua%ZUb_NFus9UR#5ykV z3*gA{BgKprC1CJ$24=m!BATfdGSuj%Z`m~%h&{JZ$^M)_KbR9=Z?mRfF^)|6b5=H-`Sj!9;jX8IE5nQ*14W}9xl z31^Z>#>pC+bXMMpXP$auXy=|bS?On>f(}Zkf_)aM-F%2H%4nmGVt44HC`Br1rIucb z=}&WJirkW%ehO-+qT)p9sMr8%YO1QP%IY%e2&9i*2?RTB~h<&2|fJxZ<7|?6}NuYi_#g zt~(vN;iAfJyzZ%oKv|3+XC>x3NOs?Zr?T>3%n3d zOmW3Gy(&Qv67;Zf#~gp`amXN#JWIwRpFA?k9+#|f$}X$yvMeUQOf$?I(=xNoHOCC* z!xsMxbkHXGJd9~ZAB}X25(kef#b|YcMfAeSf?2ugp(>?e@EZfBvxVpCQwr4}AdimIMZvKHojB zZtrWL0|!<;|NW1F+LK`K6!nSx{!W63ORsGu__88d0 z4T>&?7TjDB?0C;EIa0A}Dz%%1+7hl)UVvq{1i+P38}cmn7yB zIax?UelnSH%w@sq=*wtI(@&BtCLpc3OJULvnT9;%F2}~qT7J_l&YY$>&xuNC_ClNA zGo~st*-A35@s>DTW;f-i&L8Fzo&4Mfp|yyif&Nzq}>GohhGBu8<%P_vK{q9iRTf~t8@c*0Vi7+q;aWeG=dR*Rk^Eay#Z zYSNtQbVMk9r!3gHJ$D8)7BN-mH;dZQhO%X(J8h~{(F0IQ3e=|&%jh=iMoXjOjHVyN zX-$~=)Ub*bQVIl*RoiJ%s=@-O*W>0vJ$g^CE-0j9t*c!L)YG97^{hcnt6uNNRRr4e ztx8>sTr@UhnG;799e z%hT3Wv(GGSXI~53>_`+CY%OhMLF-tTY80lFB_Lc0OE=fXcDTfC4r6h(kbsUY0uDl1P&%e~vOx4Z70Y<(*$UYYhxy8LahgWrjN%ikIK?bpaRE~&pBBq_#WY6oiwDt@64!XgGq$mR z2`K>~2qDNr7BZ2KY~&;td6`HqvXh$(=tWC<(uQvIqh%3kN^2U?mbS_xCIISCi+a?gF14s>Vd_+?deo@-aG+le z>ts~9ShJ3`tpi=_U*-DNyvFiuDYjbcUJDz3Z>P3< z&hP%R+uXgPd!Off;3mcPI{Egvgx?sPpli71?ri6UZ%E(-ued|kTT!1Rj{?Etcm+HT z@{WuA;~l?o=N#_vblXSe7H>ImBdk%4JNKPV$lGeCIgFXU$1OahD4{ni!9{ 
z)SfF#`GS<$2MYSopKeNn|3i7^0`_Mk<+0Lc03k+xhh4vL8HqemWO@q zPlfqayIyCr&m8Rm*L*VD?=DuiqqU`SH*nv{Wp}&_ey38u1=YcKXK->ow|@tE;0F(R zlIp#$ddgo7{x6j))^yB3%<~h81{=^>jwC82v z5e8ty!yU&P2H~BNx^#QbzV^g_r0N6AdIkTk_cj%ij8&LcrGJW^Yf3Vxv;_%?#z2a%kc+vxY`{=L#^zF`m+o6&CtbD!i=}#cg zXMY12J>RE);U|Ar$A2JVegx=%>|%iWgMgy9+4ETdV$b<^RffJ@iAjo_q^m|j-Lra(*DAgy;BAqym<2Y@sh%MDWNRH_UC+KJ??5K|KC@qOs zG42SD^_Wh(I4|O8kNcP@^caig$d3W}O8`kN{wRYy^pr4JjxF*(mzg=#Udh zO$-Sw5J{04IV0XSZykAVv5+zmAORvtk|Sx7C5e(I`H`?dZ{Q}9E9nX?S#B`d4KfLC zG}#RpsgXM=eC}sN5$AqD8I(ZzZ~PW-9Z7Hf#$`i;k38v=gCdD%rj&9sm2Xp&a8{LU zla*;>4ifp4UpXd{lyC}{cs((ek7I}kcb15AV`jOQY?*jMsZuD1auVk%773Pli6vwA zfuZ$BLrH{?wR=bOmr_HQDfdc?=P*sFmyPKr!B}AfC53@$Ta#%`iAb4McXaRFDncp>>(J7h0 zS(pSCg~SPm708*$xt-nVo<-7}OvGH&d3DlRXZG2D)(M+#)Sk@JpYI8v%;|;nX+ZeN zd%o#{;t5}Qn0=kupzIl-5gI4)iAf5IngvRr_nDzcsE0h*p@Np2wGt|#9`c>_Ii40; zp5_^xDC$ba8HVrYnM^sNGfJ8Rx}X?(qtltA2Wp^VIGa@1hN$(QxNw^^YNQoHp+Y5D zk-4Mb)T7slU_n}fGJ2#{x|cS3p*otS8rr1!=bGqweQy}1;GmOMYNo_dqBy#uEGkcy zd7@0ZreX?%atczPX$@t1rg{1sN&1|-MV`n7o*TMlKYE2y8eT&hA0euzi&`IiI-gDI zq=M?DIXHn2I(i?v4S2cAsGABKj{2tq2B%AEqL2!uzGbM`Nu`{ssyw2QXwsUL>Za$( zrWrS@M(3dE$)L4lr`4dTtLm!@whWWXhMx+ZDk`dKdaP1JsAO2FxY`Xx`m53U9iCdN zWLK?wcde#Lr$P#*b4sR)Dy`w#E?AnZT3W7Lx}}jys{83p-Rg!ADz5RWgu6O*uR1}o z8mFP!qS-2#?#ivVs!;Sg4ZSL_1uIhI3aEj)t_ZrC{JMIZiKz#=iPr$F&jxF;Xo|3) zs<7#*uIVbQb_lTEs-6)$4&ZvRCtD&JTdex3v7_p-_6oAwI;A1|uHTBND2uZ{(y9l? 
zuo8%`Dax|6YOMQ8v;Qiyh?uH5tF-z-twYqK(JiPwu`w8Ml*d8TV~sGGXb_>7a7wKp5Ov8xxal8m!Ri<(x9vrxOYyJ@*A3%kp^l;*pz0KCB~ zX~Sz^s_P1?#=NPPw6Y7mN&>qD(!ABnYH1d|**mTRyD*Haz2A!r-m51S3%=z$SJQir zT5G=RyAtA?C+<|szVW+$kjpUeE5G+US?GH$48~zB#$jx7rZL86jK*F(#1M>3DJ;$F zOwHZgQP*rJ_?*uZ%+0_|&m5YzJ}AxujmnYSfln&R8)eUdywD82$?;r$mAuSKXS65s z%oZ)h`&`Xg$Iq>d!XS;(vgWncI?|cf(y&+2WEInFDecC+%%g$|qX~`B%S@sIP1E>` z(Fr=wBCD(uO(7rs(*(WJeOc1h+|(qU$w@8M+G);mJgiUc(*9htjXc#|9l2Ev(aXAs zu4mK}lFeSds6pMZL%r2SJ)NBEpg4) z*=y4sd#=J9v89^T?OfQ84P%~c)iLY07Rb@;Od?1P+3YJxL0|-)4cebA+M!L_qiqDe zDP%_=1gg#2tL@sY4co5WmZp)~uWj40joXTQ(wkk_szVOD0NlSV+`&!Uoy)q(d)$uA z+kAc4J)yh?^xSn8-AXvq%PqgJf&@s=0^7~q-R<4q{oUD(1hXLC+1&t0aNg;S-s`R2 z?Y-XOEeqw%-t+C=_5I%2t$EcoZQYzr*uc`?`^~m{eI^19-~oEr@j~DQ{+7DNBpAPDwF6x~w>50zhLFn8@qcIvIl&$XSuMX?6F6*=I>NE_%H#lzYgrAi|d4(-QYa|+noU|knGE@?9I;X&+hEe4(-z}?bS~0*KY0Ej_up7 z?cL7p-|p?<4({VF?&VJI=Wg!lPVLPO-YqcR!@kQb0Ppio?-=!N@Ar=H`L6H#&hP#1 z@Ba?)0Wa_aPw*H8Z}10?@CmQ*3(xQk@9+-~@ewca6HoCN6>squkMS9=@f*+a9q;iU z5Aq=|@*_|37A0@;Cy(+euktI;@-6T3FAwuEFY_}OPxCcz^EZ$4Ij{3O&+|R+^FI&t zK`-6>kssaWD6CPxp0i_jix?d9U|-&-WF5@ArQX_<=9@ agHQN{Z}^9g_=&Ih2#e47jsNz60029}9H + + + + + + + + + + Without CUDA Graphs + With CUDA Graphs + + Launch 1 + + Kernel 1 + + Launch 2 + + Kernel 2 + + Launch 3 + + Kernel 3 + + + Launch Graph 1 + + Kernel 1 + + Kernel 2 + + Kernel 3 + + + diff --git a/docs/examples/te_gemma/media/transformer_cuda_graphed.png b/docs/examples/te_gemma/media/transformer_cuda_graphed.png new file mode 100644 index 0000000000000000000000000000000000000000..cf22822baff5b8c19a377c5d4e0d23d3225b0c8b GIT binary patch literal 369694 zcmeFZXH-*L*9NL58bO1iD7}daBE1OGQL!M^3W9{*#DMe?0TGlADhSdM6jVy2M!IyQ zL+GI>NDV~@36hX-SMZ$Aa?bmnZ`?8NxPQJq1|vH=JA2JF*E65lRtVPBzQC}XefPF) z+Zfa@s$JQ(jUKXX+YVJaD7d0K)sF)&+nujmII}IcopTnv*-4s5d69cIOm~9! 
z^mi}bbl$d2@&oN}`*8~c?6z&2uhi8}U-N*^Q{0~)x}H|aQ004XvtnsU_Q;!$a@I%R zq~3q-%r5odV*BYcm=Tw#(+3|Xy41u;dA7&hwQAsGjDGpYgB>5`E}Y&utvFV6d+kx% z(DjP5{fncYD^a-0y3s>7aty8)x#d{vQB*ef0U3ThUNW3f14paw_MfhI??tNa*bfeg zt2=(Wd{JS3`+nNpb(;Hsy3D)qZkHM(?RKeiKV80GSj)|Ldh6i_e}A~z)_t>&My8L! zejZ^qM2XidGGXh}mx^vFvEqzBJ@KUw@qUn!{U5YTk5>k~;WetX(?5^LxdRYB;(L19 zTh(UsHBWlTPs6H~n(Yf3nBd>KSupbd2L3-){~x%%o0Iwnfs}hrBg&=Z#j*wUZQTbj z#BeFA*1Ns;ppS2N=5D{VtE@wVmF`b7qrG$=LaNqr%c`H4&xaD6KlSWD%ufoW*$?+k z99qfC6uLe1^aJdMz*SFPZ#wAF;fTmIX-vZUMjiY@8n&cD)La;A(ZLaEMSkGQ)>6v! z(5N!?qi(VO#SAxwW03xjU+P8$9&XwApCi8H+LH6>wf79UY!nz|utoGE;XT8_A)TGO ziwf^z_t{(S3S!2zn=@SUu#K5$%)2cn`+i={6xG72!C05vFQy^%_)1TS*b{QK&eJ@j zvmZaWO3LM}slWJXM%!!yx0vVXA-I37&)#d;7VSE(VJLk?%u-QXLpM9Cn~;isBBPkk z!((?svHS~@>v^qd@A@rW6a+j>I}+&31$7)96*>aq=a1TzJ$hQ5L&;yd#Dh#ZNb?nc z9gp3%7)rg{G`7FPsf=NCkmT3zRsJqhDrPD*&2U|#K=E<>!9y;1jK^7LqElG#ooQle zSYCj`yeE^Ty3&gVI0_-`diCj1BpAXEXW?YF6m{k&j@w4#ZQ8oX4?_pSd|+Yc@ke*5 z+%?p>>>ohPvKnXGf$p~sL$?nTPhXx4>^{b}rdQ}kf7MM2(b>7)W_=z#@lhtZ*j}vm zd6QQ`zh&5@K6;u##94X>e@-%4KAagXW0%n6yS16m8`#0HRb%#lTbx$Z6j_h3m;*gcV#n|8ymor zE{}H8rF$&YTHJ1wfm;%@BKSmjo<>^OqvFCUbavDH_Fu=A2Sg4Z6~mlunG74JiL zY6-hW(}lc^wYbx-D|xiVlO?zDKiFHe2SX6EweA4?#@)_#glb0zftA87?p6!8cI?C5 z5cqT8dC46j?aS8=+E+i(@3ZcRodM7}q;ZphSRQ;F)n6jpm+$qtc%u_>F{YR^#3(s+ zIobHg$TNO-+cS|m`_H(7j# z%vz)(8^a7=7>X=C4!{@;&-5M*+lh8|$WQbgX`a%H8y<3c>HXyfw|SbQp@ZanikiuP zSQLX{=+L-G6*h|bOrA=CN+dIypvKNA{k~1FXWI)Ugd+2GQFN($g7T>kHad28QzEyK z?@;gYpc`e^2gQ7@jU})z8)F zx|79XkGreDZ7vvzCD!@hv=$~eqHyMtCJu)G$a*61!@(ZV?0L%ph?_mc)hD}=we3?f zHr41deh|};ZnW#g``$pM3d6V~wqi%%!_6$A%kx-XVtz=OBFb&?oFiJv#^a$)5V(T1^0y9Q@HThYlwa*QW-g#);@MF ziKD#vQ0O_4s^i0b2y|##af=3$t?TKXN%{jF@@M$C!pe(bZ*XeGWPtk>?mrl_E6mlG30}>3OdS=ludh|`3=aZ+t!Gmt zywwpKCTb}%XSM><4?#zm*WU?6NDoY&WJY^4Kdi{r$2BFRrJL=-Sden(QQnq1dWnPv z4O=vB+$2`6kjj^Iwt}esa^0-Uj_cg#VwFqy(&#&rf{XHJ!>u=u?jaWcqXW1yAv3^? 
zlX=fzGUUOKR{y~8W;$1=vyk~^NIS7`S^wH#d2xMKvRTn|-b|4IIN^G68@Dd-y=$=x%hXZ}rFqxJ2TBfUEjU+LeSxt#V%i=T z@hW$N_d2$yTGaVn?#NtkAA-!+X00qsu(Z=*u~hG!SUQI= zR-aF?s2+{@M+bAe9NKHf_A-B(@D_rM*YBT@VPsh_foJ36od$>S9R83fd*+8&-WP^! z{*x|-$?FOIig|lW*+UF&yexIXCU>p?8LPSwa0JywDcXYm5S=e}M zLf8QI^6|gu9lWXznQr9?>f!yo7QzrWVmTag%N}A@e=|}7*gVAiCb~L#Y)0lAD}%kO zDaS(oNO2o$3(@mgqn&Z-<(^c#Laf=l>2Oyc4wMY{9(w|l#>hby z%%NSVmDnCqsK3#aS}YbGn0!_Q~8 zINI%d(cr~RG=LXcuEeX1t_zQ5SH<-;Ghy42`(+4~1*aZ8ZuIHz`aBO74@PC*7|75af1~GbD3=x|b^tyZ zoKXp$n$)oem0$P3Q0b7hZ~X(kb#p}3D?b3|{sbDwe@vF`nmqYFDIts&J%tfG@}?s# zDKu4^iy)7~UD>j_qN|eyhSu-2Di$x(H2(B8J`w8^Y*peDn|X?0-oYdtgV!iE>G$he z0UiO?-_QE;eIK@(FXZ#(~Xkz>qnlI9e-;17?T4Jxf-E6 zKgpi{#u>?NUQW0ZgtkAkjNFnn{Y^`EgQ#;Ft_`l2O|YPG2m3r>XE+p9ykE$E+E0H@ zguA>Mo@hRTP#BkpmmI*WDH!uDzs1<7_jrnqPBti{u++cdL3#Se=f@DKVuJGs*`XwJ zT50Cbqwbn4qDB>pQ~=`H4`R+MT%4Bqpx5)EKq*J!#xOKE)F6`s5JuWu=JDY<-XW56 ztzD|*0eixgM5h2Vv3^#qBrKkFBFFk{xhqr17+Hw9yqQPKC`QIDj>*V@SZbCdutnk~ zD6tsV(xMjqV%Z?9Q<})x@}p@+Wl0@CXneSmP9c8X@1~CKRpZkY#VQ5|diM^;aA=hb zPYG#zmJi)3^yoz)?)pYwLms!CUpV=<3FOqM+Cz-skOF&u7dU39PLg-JIN1X*_MnUy z4H|oi>qSPV&E8;iH(s-d4aab`$qfB2wDA?3*j}2@_DG`!{h@0Lu(ZAk)GTc>;xB<6}fT9;p zgrC+%dYNYv#vd9cT_^|#7V3c;WONt_7+<4?972v9tDt(&D8mU0;{ERA1sc(zirKW{ zDyA5q7tDI$+G*$!6BOz_4)(}8t=FzBy*jy3KM{ZJaa!{s)E12O48sG)CWFzOGVqq* z2AhZli^kE6U)=r`W(B%vH&b5Jc2V?pZQ$^C%?v}=LAozmdhLX>^m~MF zwkvB$e`RQZQiVjH*K$oxd-~w{U-2AbXw*>30BmoRW_uu9Ss0`lAK`XFSM*?x7JrSJ z8LH)Z6RCLlZLQ6_X$j^v#YxvhmVO(@tbHoPKs|-X@4~`8M8w!bF*h!8APs-b(QeN3 z>JB~hgAt2a7yN9|+vSn*7vD3am59_PH?Z!r_s_P+3GGJun3qGm>|OQQ{S%3+4f?qH zJ!ttzD~<)vu;spj!=t72^-GjPLiVmVhrNP#7I0&2-nAq&dEtjT@(Qg1m}B}s_2q_j z^`Mk=(!yR*&FZs^Vh_RzwD7&bhM2$sCosN^Yl!ZjDdI+%^BN4#)kYF|97;Asb;>7$ z0&eCC+DVhUPzXV-`GsE{p=F-Gsx2GTiQ`_uF8>i|`FJ(_Q3WTUN(GOp8&=3NKGP0| zRrcucfE27%ko#jy@+hlag~#U`(_MGc0Z{{ft}~Q<@vWlDtR6u*Uq!rr*&|0G-rKAr z>bvwkPm~<^XrMkMeE?2Ax}V>PTv2=^#`^~6a39YCaR|?rF2{doxS7!)Mt`mb#~K7^ zTFsE>-2T6We`a#32_BiSTo~8gF#724IEgV$wSl(Gcg%OB*NX@rkQbjKdJ03*Uwq3G 
zM*5VM=Y5jVr77K~CzNGb6r0@0>Y&6io|B-vF=`BQwyKxSpNbcN3WE(f#QFa~Ii1H$ zJs7~VGhJ_lFG$9(4B?LsZ6V~fJUXx6E-Zrz(jWj@dJZI8{mmG?gb-%5c~hkK=i*_T z=1Cc2$L6cUq~JiC2d9BQq5hM!y#5tR4V`}KMI6nJ)#ob;<@Q zRY@doH}uCAHyTDh^4KLu5xsDn=oH@7Q}RRGNAr`@>T?)4(%)dlX&gs#&XP5EBxbU} z2}~`eAessu>R68H&r-S&FJ=Hpo@uaQJSTAc%CJ}Abg1PhuMZCo9D7j2esVdTQYIoB z?^M#Jm$-5e?GvJ5TjW}m*JAv4`j7$iv5kcjCOTK3A1p0yy;*U}|GvHd-Mkp!zFC)e zlMR)at;u}7uUPCSgRClT-o@I6{pkrqJ#hR7(OxIosUf@2l1EE9;jSzTc@x*CY^Ec3 zqRof5;tjPPkawQH3P;Hk_Mp8q)wp&n4<)PRVm=j(+T=}3HFzbjl;u_oT`zp!tM1s$ zR0q$HhWn zyX@sp#on(Znr2TLmtzmrA;@Whk=4nW?l#VPTa_gscx_< zORGnWv7jl)XM*tyt7VFC%isVdZ)v>8=VGm{o^z6>u~oW1-YX)X~Eh=HFShd z!R#UCRuuaef%>})&N>>xvmtkh8}))wST~kT(a<65CxEyb3}j{O30w=rLHr@KqP?X= z!vo+MLGrv`dbD`vAwX);B8W>b!$HVxSC%&|ODahT1wOlP;$&5dtL?j3hp-}1^A|@! z-j-ESmoJJO)t{CHL3BZwZuy*qLHU%9T3p)A3~YH9rvTi8OXbPC%t$M4v5nD;?hm5A zzoY|#pQwQK!jt=>aTQMB)g??5cS|Wi>*1AkE z3Cj>IFYa|HAJXSrHn1*qWfFnof(f_mJ9u>D7S<(KLul?|l=7umE3dc$L^}TV{qo zgd^%%b(98L#2d=mPPY=z-45;Q^HWNZNL~mF(5v7#Gddk|+ubKyq1z|u5O5cTWQia{ z&+)@ppaPkOiha3FFPTwPz)E?Y;A8yg1uq7YhTNEMzd;%%l@HJU7Uy!F32N3DgPTVsCGO6e3x_*<_-(lJE0L zV`1{gPeqB>YX8GhRF8^c$6?J6$7McQ8{Qg*@|$yABeFN`5pB6=k9&xfVaZgkXlCAH z?|M*J?=semX18}@^=^j2U5&1>4Tl{qnYv!7kIQ+FJ-Fai-h7+N zJPacFE@J5I=s+ruZYmFr-ZB)x!#gUmFa< zZj8qlYE|L>dZ$C#8N+s>wwI>*z+y@mzSLi(UUWTxC>e>l+OzuP>)Tp=$1H)euwtS# zadr9A4V7h9&euho_F=n_Z0Fe8wmCX356M&!>(+0cElQb?3p^Lv zFY&y|6}mhGuL`&185n2@arcB1ID0^Kr(fcrV*ik@?DWUo2B)F-z1J0) zt&y;LkqdV>NjGKf$!#qj54JB~87^I+)eRX?l*_j$$FnM0)mS?`8?#1^m4*-Eq)3-L zx}G+<84Z^f0O#QzS3lC;>22m~1u{1`=3B)nmSHhz%}o2?)5$i=@hb2MgxIU_3a)=t zEfk4;z-+DB&VW9xqJqO9c%_n0d-}6|a5hW^b*(SO73Z{g=uxMeJsgFM3rP~4+(+ZR z+%WJeN{IwKEr+8F7R*27#uo0gcviMGKRu`MVtjrBQW!hXLd1;sWL9LS1~+6WJv-`C zYzO@AS>T{U%jP|byB1CrWlEe6#Y(Vw899*^*%u_i8`9Iy^0PN`^@ZDIjzWmUUn_L^ z1E*8N_CnTAe&epQd)$sXErs^$8@E3z-ubC`{-UKjho<`}KN789 zvki_zXG$^0eyUtMVOS3NgUb{J6vb|HfMllsR?@;a-dT zCKl{kw%@^U+0Hn?3{|T??5c*$6?(gX#jvN>BYew@V$%fEyRyQ%4j??!>T&uZf%>S( 
z>U64LDL%SOcBGk6n8Rb@hQmjF_j5%BPwnjXpzcKTp-X{FwtK$zPiI@likZM6mY__5nn|uMKy3H^!i?D=&BEJ@Vs~ z%EQu`^8Vx%FS^e2GHzYDxQ7mMTKc~|Ai$58eO~6wFO9@2Tncn*TUWaOs8MBPaN{`q z%YJ}25nMa?wyxN2H{}h2mW4Aye`*chzt*_-kNvbO_Y3cfGphNIoCkXNA)o47^Id*h zS5ysDS?SWM?T-Zb{rGqrZ!ZY8x32Jt@WOU(Le%yrTh@2)gZ=Z_T>5Pb z!zEXZy$R3o0lZ}A-)^t-D3oGjWTP2z)2{hParE(woo}vY*8DFfW|69dB--HmaunI< ze$$sXrrko-V!m~sDKF1E4y#hT8dc=lVlQz4OTWcPKZ@;qX%0hiu%NJ{^~10)5VzXa z9iCdOKYdI)CFLZ2yx%ks$5W%KY$Dt0(!ay6a(#*eG5r3Cz`s0B)4k2EW-qtmO!LV8-W;$P-HTC@)j`V(c-fXU?)&N*P!uoUy{$B@ArQflC$JQ>@pQKbN`UPo= z+$uh|cJ1ZtOjbfB`eIXlZo$x|-;nXL^M;U3}r(O7ipn)kOHU;$Lby)dbMTtN2#>thX4y zO-<$f?_c?Ox9pa4tZ(>X9p4tqmZH_EUu*Ztd#%2vO|T_aEbzb66n&$`Q~qr()D6!h zW$&SNG)3+=tyq@cO4RxLTrA58@b&@%X`2zTCkhB z`z*4m!OXVEPd`!t>C?)&P;4Sikwlh0?H1L@kKVIi;t&kJDfRgtTX~?afwzcCc+uFS zAbo@L^3y{YGJ^lsSlUT_0pi$xisY5ex@2z1(8R4EeW#(ubcybg)t+~Y2{%N`*oq0U|6mS3{|gbeUy##;74qm_ujvP)LtdAxwK*|)=8Z1IFPAIw=dvFK zwz_|rk>FDCQOtgEBqpXYk_S^++laEuw2voJsKh9F?~T`Z(kY}I(7%J&T;~&Wl1HO@ zCVam1l6#wsZ^?9kHn+-LxprqRkEVSa)0lSRsTUKC5rX8}uN01O|FM$PHnR3iU0!;r z%j}f_`Fm^Rr6hcN+{6U}vG3|-9(OBpt6TiSq`G_Di@}bJ6xY=OiE!#C5>b4<0v}&S z%5G0S^%j$<6MHcsXv%|wwCE{f=-VLJA17sf>E*4Upf_Kt#C$VdcZt63SLwj*AB(9!!jv#+Yh>|xtbpt_FFs&~MG9P|k8ejaiQ7hAR@yliFl!Q&u zTS)%weF|^IGfgt&95H!kvRJzQuDF6?j1Ols<5E!sSX*H=UGmCoO=Lz36V}OpqBr# zR&~sGGpRkXD%CqMau$`Kb8bEedJGornQ@!dW}PHedUxe@dr~-6;{XWYFS62te=|Pt zIp+arjoKyFRNo+^KL=clabcZl)ty6JtOShR`K=hg$P^^eU0jj zh-qqT?N+Gdr@05hro|h?t4q>gJuX8J*$=(sx;1<08dZ+H#@!!SBNb{?-9804Cmn$We6)BA#`PbSLsi`YZ-ZXnadZT(Fg0o9@ z>|QO#ulrVwYMwfHpzd1^jZ@Vc&d!TGhT5}H4%#ytXs~LYKz~8ceQ1(_*NahyVXM;s zjI%xay^+*JCWCh!S#i9kyqYgn?rkRRO)5=&HkR0#kdop!6z46>-j?U{waGC_=o06)nNj~dl!ZEj4ie`9Zoxb7vhtuR+LwvE{w*{a=tjes|tDAd#p5-X}V3awtpfMh-yF!+A>lCw$e%d?(V&iI7=~o z+3jqsS;D-AT~RaBd}_0oU}5Z!U%0vG-rTtDSK%yAr#(J+3!^pao{#hjsAhA zI7m-99}VB_S0Bv(32mBOuS)X*nD}h%yX8!IV+r2`O?ks^&~n^a0n@EqH@L8{R`opI zs6OD}=$QBAO{CmLc|852M}8SOidjqAJYDNj0#ZbE^%+!#N(u2p{p&@cxN4mVuuvd{ 
z=6|BiIzwJ^86S3UU#ma7Ll%p_ENZT**gdfD-HkC_0b{`RG1u5K>s-;;EfpjFxC<7ug_Y-;&|1~T4)5vpx3#IXHP)|HePLc>#SV-* z@Ou;;#|)KRiGSBhva0=%- z)suk4J%AgxtQ}dsP@}pGgq5J7^&E{FenPvg>V-(|-dv-yMBT>}xs1sFG|kok@yDJ! zo6*x9bAgyZLK#D5FWET2k8!`oC1;gZ*nMFmcZ0?bBPZz4gK4Q6)nnE+tW7SCr*Mn0Zu;%Y>P3fQNKnN0vx%Hds62s%!Q<7l^WGq z2gh1fe1BrsKI@2?p5SzO4d;Ke!Y< zbPdt=#~7oh{w{Hbt&k_ruFATB1wWc&xgEdXG__@HHYuqe`dZT$xv&@|@7ukKrr{+f z0Dwc!S%vcTGB9ygdPqGms=6-{0F-PkGM`h_Qxbcpyc?5D=uA^ri%*%RK50MIET?tT z^9FF4(Ul6`*ESDXvo_t^v4*RHZZ=$)6bC@ zmCa$r!^f%OkK-DJTW&dQzB@)D5C2HeRe?>1&ARv>8UYKI@Pl-Kc|-v?0ImVRorwgY z(Qxn)`IISB-gIX?F$dr7zK`~4aICoQtZj{|)!pXu6H@v-0e+o2V4j(?4k>a7(%6M0 zvLFC~2F;w$TXR~AnWk!rvqB|xSRw>q7Mv@Ko|N%D_z=|iL%%;J|nVJBlL`#FdI#O{4fh>^qK!r zopj@vgRCzuEHq9&4@bXWL%T+mRo%JDq7KQjx>)=ntCAVGcDwJE86*Q+1wp5L0@ReZ zmGghrGM8R}Q`zlIU<4;-fbP#_1+>U#Mq*=!N=~=9a0(xz871ie^n?DOnoDF zXpP-&Quyo|5Ohok(0aqCvl@U@RlWy2YD`ec@C}Cqo!L)MCgkr)6KYkh9;-y1<%#1n zO&y>S?D5`{va0+(PKUFPumXw_AN<+x>L2)dFq#np!eb2ww7TniMI?BMCfu{S1U^WXxdA5e3)ANv zJqo+#MGN0RG(-#E_-O*s!neJD&V~bE(|-Yz^o9$Zs4|{TT?MH5lf%LHTtGks^hJv* zz^nR2*4L|GHo-AJF{CuJENrz!tbJz8I&w}d!0##d_lZb~fCWXhCn;oJ&D4I@WWpAc zyv5hLsUpw!l+CKaOKGYCPn*5-#5C2t!;TtXQ0hVg5h!g@o}U15YlV1cy+c*+!>9{M z0{r{J?n1ziO>h3=j6#{9P9(TphYguJL8_m&<|_ug{wnl`chcD@@AM^zR>Vo(`xiU; zZEM=RBlk+jJ-HKv1;-TW9#Ai_(qRVEKwu>=YuGB+r2{ep+Rl*%(m7 zA*iJ6Yn|k>nR|NQY8x=^s67jI+J0Y-+w z7sjeUpc(%?+Bb?4*S$?!!<#>cV#f~u!)OQZ)%trbT)n(@Ny1xrb}Nic5`YXMo{H;^ zf3IbJ`m&tP1{6&TY&sLO+tP>mlD^SHb{>Vf0QP-j3xe8Kxy6k)6Ql@n-Sf}yn5HgO zl1FWQGq^y&sTQU~bJ!r{I0JwP+?QryfZsTnve^o&4lA9M4%3F@#Cv&PHI2R`dOIN> z5WVI5pg$M-5b9^&l)SEi@`}@$)5PiL%G9|~ zDX$|TL@n0iyr!vkz+$wd1Yy!7e>`b@+IqE7%R0$V5Lxovhu@I3sjy7tjy$gLEQWg_ zKLF@xf0xmJlJFucU3%A-{Z`&yH#XNN*|hW>Meo3xI>aA4el6>-;+~%yqlX;h{BGut z>h3y;oPfEI>MkcKxC&>#5J;FRVzzEibbO#uZg8yP1Br8!^{L;P$u%_jj>ZBrqThq*V=ur2JPFlZU-fv7 z*0_{co&&`EL!&f#T?df-FP%qSHyUq^wI?gi?Rc<32k`^}$EgKPmR9%Oh0NmfWe zn~y%SLgNokE}#EJe+n>{xLMVC-=LIqBIolzpF?72fzDt9R z+{{&@+7B!;{_oUm6X>VO_+faLR8?+qrdGhLhj%E*N1odN3*8Ey07H;o&0_gWeSxX` 
zsuK%nrjv>`2t$^!fNe{hr#-av6hUR7t^FC0WMNupjp3)+Poq)9dVLJchX?}jF{QbP= zl+T`%J}&TdmG{)8Dc&XxKM?h64S*D=&=YU9wX}et8VwNAPx*J?Pnz;F_$ZXyOSZjO zO6&)5rJ%JqOnRl_t2cYr)YL6?V1Vn5qGyk`xoIDUMN7mNOPZcbt+GTvOr07gBpiT2 z&LPX=#MFIhuw=CohaH^X^?Wk}Yx0I3Qc;FC8*qm?BQExeKdul73i_DKHeXz;N|X+k zrs04c26@=m_!YRg{ieLD>SZcyAG-qlIQ>E0fTuq9h??-two}RUkfc*R8EOF zY07QD0QKJ)AT+k}s&IUJS{D_mTlPXAD^}EKH<=?N)(s1OyL)T|uAO1quby$ce^16E zqK|L{F`wy`+|mFp;vCZRdF2p!v`ypK0ROxzi=}!DC`>*pByXSn=3O`1kV;>`8tdfy;-QqS4zGawUwKA$ppXB%>KBaZ9 zszX`C<^XIM{=I?my)W`w#ISf~WwOQe-^%ZrojV@T{qb8dJuT&@DBi7B)4ghwicC`R zg=gxdxKROv1_TIF=E#k!KKLcFFhpw|8d8^HBt6Y^!}*6Z8e+V3~qBh zf?V=0O`exPX8PZX<20eQSZBU1PYj_QctG~>>+yS2g%?GC_SZYqs(z|ngZ|lKJa7vN zZklE8OKvo`@ZW0Zv~J!>6upA;R2UPT{+ zmb+hRnEL*%)v#Z?PjR2>Ls_;;%Kz3w%P6bv4*;S0!~THZ!TNinQHysA`TyGaT3j3R zMd7DAaI3SJ_Ns-BAFBL*HqrGP{*rgQeEVvO7^_Um1X1cQ$KQH`T6CZS>-)cGIJRX~ zgR1zXB3F~*v8I1(WM-7{!(2ALIf9lC_?H>KNC4WYzb)(IzpHBOH@j}b+^4nELBIE_ z#lP6dR-`2W4({YR|JOC( z_xN11Gf4^A=V|$z!uH$zf9a0VI?h^jO^Sz`sQ=#crJMQS2N}w@49&e&a)9_R#L~J$ z%&7E>&E~%}eXSnw)442<*T4I(%`;ltJN9quWXAVf^WFz^GXHv$2ks2h4+a z`0slHG%I~e`M}Q_aAhuH|`hx%7d;Rh;s##hmMLDOyx&zdge_N@acT)IylZk>+ z!&Sajw61nIhwL0s?|)|u&?i+n9wm)}Dmab@I5WooOW$~_=ZK++@3XPo_}^9hm9jQA z>$~ZFlQh=yH1|EZD%N#!Wvz5|#GQnW>!X%;`fidh>7Uh+(K@g+kY_$rU>!-*ou=3k z^wJ0GSTSE783%skdf*bob?T;rt<7g5^(zE%xHtUXFw!Pt(|xvRyK_?b+ty9w;Gl9r z5>Zl|;))pCv^&%O3X|z5qGMB%8?9Z!xc}-oywXvR(T%Z92E~%TWd7a?r^&!~b^(KZ zOI2b?%CS81?+Vou>Gh7{8 z@<}%$>&OXUeRBqK&Q)PwTKaCx$;WOk3wf^$>Z$vl?TB6iGzzl+iw2`M^>*=0gh4zD z0Tt-$dl;E}~eIxKV!X$2;_U3-d1 zdp)rbu7c=G86A2dR;7GpeDaK%<{{KA>S?T7Qs;xQJDpF*d|#<-OlT~qn6xL%;+1*> z&h$GdXYZ^U^dOA-4fa)4kzZ-2$R(_?n4%<{VspQy+^3&M- z`pjqDn4{=(mV?Wrq#y5Uv<@rk~`?pRDv>FL=8+!sjV7Mm#?Z|6Z5R3l)1d%l51ZglN6VR(>DpmXt=wml=<~3 zg#l`F)CHuHhH=DRkj=j1- zByZg6XYxLGT4G90em++EBLQDj)nXC)mW!tv&;XXX+;e7sK39$7s z5%Po;Lqh2pZ~O?#4la){8E=fwyH~N7M?ud__F}82^+d5byNW}Rv$ABrQQ4gr@eAzo zt`9MT6)uAJKpp`9MtmxDU*4s8H$m{um{+c91~cmWCt}eo5ktpNY1OV*_g_+ z)`?`L7t;k%nPN6XvC5uL>O5KKMx}!Bfr76-PVI>ug(v+ahJi@6poWVqXvYJoB9;~?Dy?d 
zCY3(a<~P%Bko~Xk+n*mC~U~Ys!WOm%NvZIpp+IL+5hT z^vSh;j>OZ$X*dyzGJ{{`#z6}z}=7{$HoOFWA%^3KbJV<+chrI?f2v(psgXz3$)w zc7yyXao{} zqN-Nl^=vlg-B$_4BP5A%IG0(CDu4En61f(j%xe|sjrR~!!j8LL;Md6$FCY){-Cy{fTzW3A1`v10kdhftN}Jn@t91RaC-+GF!_Y{Za2 zkYV^iwMk#*PQG)f-p!CedNE2hLG<`v3_539nXn{x*8mff>mDJCo{;lU7H=%~Jiw|V zyl1txjmhtP+pS69-m+-!ts4K7t%|&$Qmo*85Wgykhq606j8^zMA%q=113@~L?xmYv zRdI~juPl-gT21}3YMFu)TUxlbnAmGNy1g%7OFMOB>{MERtSZ7xVv ztTeu9VkpugGSl0E%Hu&ZyMl~={&C=nGE0SHNyi5@T<&@`8 zoxR!pq8gXu7o@Rv#aOs#M#@yf`syp35*CeLdvJ(^$0{WH=^=fp!igSx?p4-!T3V~q z_DDU01^DgFBEOI9$uxPkVqKLbGRq1jc|-zQ{We$OKe{t(b?&MA6;as1dpWqyH+gGV zP)^lwu4I?+eubYvf0FAj_d4a!J{LrbLvTE&93LuE#!b%qxSxCxdOAwar~(>~oU6_B z20d?Y%m(XBsK^Cro2Q;;4qc6fa#L*>|JJKM3zHV{8IiZ)66TBPt3q@rA<&I#^Az+Q zWuNBW_;{5i5xMfD-Lg9Gwv;jNT@Gc2j99Cb$i%!$JMn!sGd<7Ft(xkcGJWB5c<^E6 zPN8QCPDVU%Ip*uH$bt~@_I%zdU4)8nh*N)TIna8zO@;B>rq)octito~f z#~lEjJNoUf&dB9?>`k;8vhz?N<|B#Zp~3~52IUgoyUyg7_~VLYpmvokw}kteR{%P} zt(%nP{Cw-2E5;z)u&ioG^PKeeymBM0R&~m1hSU~Fs2LcDR47OqEKqiZOPGQBDlkh) z)E~~3pGpgNnRg6Fu1=-MYJ@C;3Vqn52To$~EakH?_WRzemrT&1yZTB_$~VU7Q%?A{ zBf*c&Nz;u;f?0dW7SCsZ+WUr$1+IOfeeFEiATa28YrNj^Q>fj09{E*Pyrk~rd88G4 zye%tll;VPl_8T+OVzv1it%HI?$U3)P?Z>F}VV9VoXM9~qQO2BM)XMIF$VAoDibPG! 
zaEb?B$5?Lda?KpwevKZ#DCKVVhYmS2&&a0-rtOIdCdu<=B*`aB8fIPt4i7w?I(BJa zY^2VY&hAz&7cJ{_7(X%DcJoDlOl9Wi1d4J|ki9nMr0FxTky8b!mw3f*c(hZN9$teR=X@Z zh-J{!khIgI^#{0-O61hNY+35J>%jr`e>5hn^G$*Oc*YM~CHvBvp@kPW-+ewEpTP0$#r5${}4 zxpM{Jp_Zob0s4ogU?&cMy=h7y8*A30)Mxz`ZV-x~l91G^-uzn9;|@8++cIS_8wgZ|(W zal^JVy4Q4XGU+Bzc^If?c7tLI*x5e&iXQS1Z7$Y-sw$NiW>KTMq(ImYzYsSbugZG- z%l@I~?9f_zia>mOT!pcD0oczrHD#0CXtWueBL7(c?CLTFUFOPn1Bk;m9m9v(%ZfDk zT$m`!#plIuKQOB8{e%hforLrrL*Ms?F5Pe&}g%8jS`j@_j|KS?V2 z0AEIoog!$l4ueS8ZuJZ?kc|t7 zwL8GCW+*Qf(ja@aWFIwWuwFgn&**g8+ z=U9*X{RuLUA^CHYtAsh(>Z7kzFrREw?7kKf`sFN!N~7r@%b<_fRSuEM=0eJOSKwa- zgDoo%?@v7!^tmiUF#TDL@ArW=e`@vJnh%0@HEVX7xq2Om`!0#YT$0tNJ7oO=uTe@L7w@Id+8>HcPTXy##&3p_=Qrn5KL{ z%^M>Db5X6>-AdrQyY%5Z&L~?ro8rDe+?8vZ`p%Z=Gmu<)F&+G}lp(=!L79~L21i(F z)gdqg)BR||`WhMiLy^%Rmf7HXLHd$#7HIQ2?}~(N!`nTzs{DofiEL4n!nExg=&NE{WEmO)ch;z8D9_6` zvs&MK!2&wHFM$$)^fV>rN)zpGUjG2)&3yzC}V6O!M_yxfO>2sb3bn;-pU zs8W@8eZE=*`jIAmS(OogNKdzLPbjy~hgFIr>u$^1Q^^o!^r3=uBI2QX&}IzW(@DE|1u`NN;Q1YieA z=+l@=-Rg}FY3F5yxvTF^ytRAP72w-?Qp&~1>ZTlU@Dbm)8)SUGF!4^W^4cc7khD!R zJ3~_5>9f}NJdMZ(tKhVghzcj|ba| z!Oj3J)FY;hoPAi7#Cmh=F?OiIbbz07$ohLJEae zN{`nwW`LY?&XJD)s$+7qpX3O1SfDpoPy`ei{9Tc?@WE6YPG#qGlFqj}PwBpQdNPlE zsm-x898S_UA<&9=EwNHd8aDi3pOc#cp}Wul|H^g(G(MFYyk~j>jV{Ec{XcBIcT|&U z*FF5s;NU1A<1h#Yfthgx1rZbi0tQeN3mp=A5s?;}ln@9_q}hNO5s=1NV_UHPU%Gs z={$*xQkC@Y(yov~N0)|Q6+asd4l~YR|5CZap^-AJA97Q|#LqS@yAGj9%b<4u_Is`P zk9?Ar<4-_7Sxd;cyg2zrkguVE_2;(H_KuCT>$YB}O=^Ktp6~Y{&;G4%C{}k&=Uq7% zkPcm+9dZ7=^+#c|&S}SnpiXfO=falyYM`sxE;+(S*xq@#dWh)0EepgRKxP?!)S~Y1 zsPq7dIlKLlGhIgf-uDv!oTQI~aKGk8gO1+&s)VpPyP0eIs`p#;4DDXG9O2_;FL``e z7wwCQ{JyidAWGb255}&1$npkJeRfhNOb;Se7+2c78Bjm(`o-mY{O<#375+9|N z`0r>0p^!1r+vvD^CtRld+?LXxmLW%>3$>PeZU-U=dzIaXTo@jn90J)X&QI{sDS zl-JBs3z?*c9(J{LQwVjunUsACySEM|s7L%-2Yhv}pI7MIYEu@Ry3{bPv0>$J$;w95 zDFdO`iYJ3pC>-ii)zV=Dloa8sNb$37H~x9ANpVp!<$j#f0WQ#P&2XLaJ~^9rS)l+d z7_*tY^f!1H&e}=G%rvs=AziEWyKQD94Cb@#DIXYB^@Vu_x`v_`RQikHUeOo4lptSJapmBB$;sQ zir@^TKSvWAR6rB_=-y;MSSxP{pE@-e1j>yoSW7D|0a^0TlO@sOjlPBSH&F*#mrT{^ 
zEBDii&vpg@-6@z5=jD@v-z+k!%ep_1&q(n<4XjU}#9How>G0#n7c`vmNJeG*_X7&& zJPkAUZ)4)3v|t5$=BoA$=)2WW6j1FRr7C{x|NAfvBS+)~?;B_Sm>qf^6x6P^4k5mT z3>0{~m(p!jg4RK9=}`lc)LqtJtGE9-P4K4#mZl#eK>H(6_WF7C1tpWB3hX~Q5^9{_ zm(EYqr^)aeSxW+K%Hh;t@?YtW>WrG>BnNh=P-e}I)}TplDJ{*EqgnHOO=AR8`Wv$4|?+s%nqE`taD+4x@yOg-eDnPS&?h zoZznn`irg&cDGboeIT78hnn zJC}?48+a=^_9toYiLav#6c0|g6r8X(|0GZ%z>~6)vbSHe&zgsYtgAM>S_Q!L7Zq_M zan*o;frx|~H2ympf@}A>WG{j}A5EZ(Ni-4%ISjIg{HjkPcu;m_2#U)fS|&SR;x}SI zN(U(bq*#7~;Q}X;4`T`(bHnPKshb38-N|8;IM{A}S@5Omcv`i`hYlguAVmkg=kZ!HMdvQdX{mT15z0C| zKeyXoPIVhAJwWV|U-b0qs|^PO8i59_F+@S|%_ZF45vRO|KwZBfLHGSK(-kgImt)p4 zIPSg{CSo>PY>KUyOlQqYLZH5sN-+-YZ7feQ5P!$X(acSlNCal@th;L)uOX%3h{E2pR;Xkm#q#j3}Cv?M|^1KZ|zhIKa67esO|9A2r~^Gu~?o znF+rg<5K_|GpOkhc?Li(N&oK+ts+prhQsbeyL!dz*zWu*{_eLy zyO)MD+72S#;=_oq{vF5h{J%)$bUq5acfTed*Vfl)FL^j3KWoe5wfEyA57tItkkoxQ z6w=gTe{sLlti~0#_VtwlByFnI;A=!0c)TOZ0#quS;<~GtuzwyDROcB7{7jnLqDj{X z08a;R*RTH+({KT;-v!~~{5I>n1eP{afKfq&_Z$iC*u15W&5@=UIU7FsV+!((b1We}68zY( z;<4rk;?111J=wxZh?$#}^sR~B+L`oNRw}P1PT{eub@3#qK1shzB}Y}T5Ugxo&T31> zi`nj9Yy-TyYzR3GCG03IU);>|g|-~_CCOUk`qo!0yH9>Yvitd?-a}}Ehq7-BU;Who z@{A+M@_WhsdMJY6`SwHZB;O=g>#bGBTopO}4TbrJzk7uE7isa!%Nw~RZsP7-KcRFF{`21-3< z|1#nts_*e3kipbi!NI(&A3#XF`jYV-VdTusmvvLG{ZA~=U#U;fUdB03FVG(!i9|3! 
zfzyrEz>UL4`6H3J&x@F_vKpGD4~}`G^M!J;f)gFa8`>(qbobU&`s_x z#Is9neJ$cLn&f{%`a-?lBEyOoFg{{EAa)rtD{y|q{8D&S4ew>ufS2tfiZ5|6$B?Cb z|EUHl;*MeFc*#>yHoMdX$C7oy&&e_&@_1G!jV z91U4qMUG5aAn;bk%r*qwxDC!QZtFL(pa7aA#Go+K;* zDBXX`cZq#ws}G)1_7JHy!3|>}WT$Xx$B@B=h&Fq_L5YR3->O1O5h-xr_nf{;#xV&#<(_kIY9zPwuoLQl~I&Q4fXbI z4sUdIS0jjthC}#kDH(Ff2y<-(#&uErc&Q$t4(G$g$k{fQfd5xmNgu{`bvm!jufd)W z+Ao78IKZEGS#gVuvn}<~ODQiwd|G7yQq+-^&hzQt??mSLTPMKroey9eLx06cl4{KN z%0s0k%T=>O%R0T2>E7FxM$&(H4gJzP9{}jq)x@~&k*Sa|)!(E<_#q-MoW!B%w2}(~ zV2)$pAYBn|Z`*$ibn6-M;O{9%PlOr$Ct8(SRCdmu(v0-|Rd8*{vg`DGOUHP4_VmMB zu}XiQ8=q*7Cx`Qb>>IN}9Iq+g{%<1$NTTeXXEpL!f8!yQ;W6b6XP_RP1s&z7Rea}C zxCEg@=q1=sZJF#ksGcV$K@YPP`%*ttTr)C>-(b{g#VPN{)=Uq^u1w_B1K{~lM|}2q zTz$27(%#f`Zp#7fuI8|M?O&u_}AWr0orHki7!SQ+HxqSipL;Dtd#b-QaYz^L2*yt`GnMu zf`Pn+N6l>#YKABllrAKS#2Q!_!=UXkyY?jti6GWT=DX>u>A7?JF9{{NsUC$TXhr(F z>l;6uIAzxL>U#urO<2cjTz%R}u^A6T2-Mme)ABhb;}A!aETpvRzV0eFmt-y_C2NPcy5nJKe99Okx78 zACwOr;p-f9_3c)93n-TTX=r(%uH!`;J7&>uwHci@8xTttLnlx zbpWB@#c&m;U+c0l9z&FqV51`N8Z1ED?{EBfF?1pVL@g53_b+3hz5@vF&?f&`spl&9 z)V{}mfZwlzZ>|+)EAC7zykG0&HyVe7lzgi-Q$1(faO)px4nSwNGDM?Y@EC{Dry|I= zO9ChmhIOhU*;Dm8$8!OploKKkD5W#S0vjG3^M;smeeg*7Oj01)Wy3ea2v=5Z4u#r$ z_}MSF5*%gTn4-Ed9%q7WfS@tAL96xKM>+e*PxVc*jDY0iS)Y$io1^#cycegzSm+Z?AZ7=P%;3&@(qyBdn^#X zynlu7t~>yOFCR=uC1Ap%w`ZsXhaYm0o8|V9v@&n&!0RC)%JkQ!5{R-dx2zF7RyqS6 zXKZOF#$mi=sFhh#X5c0FU@6_KveXS~{`)5!fme)k37O>@T-$SaqzD3c@m6>DxQLd$ zpIxv{p`$Inm4hty2N~UtYlPhU>d(!s>Yj6+IZ$@}jxT;k9iu|8omMtpXw+ zsf#Gvt&dPW^QLjbW8@KmrTvGakQ~Eos||&~ZmrgaZqBqh_}^03KZYDBh=x!$U$z0*j)G+YpTxlutwzb-vOKZ(B}$d`cG&HQ<;dvoGd zNGA04S4nQ${&ii4(e$^IsAW^8u}z68!Pm5=jf>D?hrl#Y$$B^I_5L7%)-!!M7Ic;O ze;J{C-1+)3ChA%@d)lG>4KovP&AB)I2ql9(cZVjq?HR9U^}|6P%>UQI1po)9 zD4U=wy|j4u6lGQiMC?W}h*BbStP-VC#Gf46Xf?^>?}1cqd6rLn=nV4Kpp4l zu;Nx9VXzKG-z+{ro&{1pZ?`m6-l9o(r#!vvX^(fAttld^-k&vnmQ5S3FdZ53Q7wx` z+s-l%lA2+H_gN!3W4@hRl^ZcAOemB1T!@c(zR^P3RVjFit?O`TUS@_RPl9TvEJF+qAC2T?s17qvZ?UQ^dxDo=cNOwNrX z!{w}p0Y+{Hk~Q1-?ptA7dphK2en|O@;jO`(3qFQDZb@;*A-dl3jDyDrZ{R+|lz%Ow 
zC5gKkQ)gedC_0ehHPEIh$b=LGIo(I1LVjvhw=J7jjM3V!Z?h;XT@m%Mcy>6|^K0c~ z8+rJBL35k!TDsI*a$!~ZQ+h@jQ{Pi$*x_E^!8X=2m|@wpo!WyHBv{$X88iEN_hh8X zLf)Z7#+g=x(pCbyS;t(})8K>W^%c<7^qh3WS#}LWCjO40L(K-_XV>brT;E?cjC6kH zV7;j*l(l{17e-r?OXoN{CTiZr(C!V`zkc*DJ1LbZ_YP}+tLQAH9t-AqRatyfphv#? zv-MICM=uxuqm6yS0k#W$;J*jW7ki9Pe?FBi#lPJW6SbmsZvIaz?wf(4wxqL+i<>pK zUrC#BUNpx>QrANgZ&wbtE6qH`(@q6#jjV6aWygdKsO%xTd=;65p8P^LGvtHsNwFdZ z-;a!uD#YFS0Jn5$=C``@H|DNaA}+|MPTlkK0CS*YIi{6PdA-V~rF~26rnxTE zzt=V0#@lf1H^@fUYT9N$%4-Tk=d zbANA~%`=@1%O5YE?n|?2%n-mhEiV^PE*cBTF;`d zYB%!9+_R22S)SX&73)u;Iq;1~lfqA64Q}+IAA}HV7b2;wlES1j@s0CCWpRTb2EkFPKo2P3oBD3aAC~MzPIhygtqA?6NG*0+Upt518fWJ z^f)h^AyBTuD9YSBibD({uoU{m7vUtD9Ks$lNDvWiTkbj6S5o z_jXdu8w?~G0Y@*-SVTLt&Wvgw!2qv-d;o;Mp0s^huNCV_C4`|j8;4?{=C|oT=Cjd0S(7;8T~wsnIJ;_5Xm}7PeiuEqFKfp@&1Z^$jga7TkfxvtrpS)v&Pvm2hk}WW}F|Od$~}tqtKMXQD$E+9rh9u2nRm*4S^3 z9?3^SK_7<^fwC)DLo?5WJxAuD7*^B;BUlPPB-dwdPcb26^_A!D)OpQ(sU$Q4sJmVw z8e(zUD~>rRGsv-5!QtX(mG(|)*_+VfI-}(c#N9jU?PD1OOLM14n*EO=lbFFm_HpWa zS4wJd8_TSD##guZd#Jx!g6bU>sEXQ!w`mdi)!v$AcA^2N8uD7*yclm?T8}ckNwQne z-Oqp%#isd>1xZa5C1sM1$=p|7w|${#U;TZhfH27@C?uS-RalK|L#Q4^#7f|sn^BJlmN5VV|s*; z>rx6;b<6QVj+-Ko)h5NmGsv&_&=ku3=(xnsv;>M)A&cu{#>Ffks@{7eLuXWAX;kY; zLs|K^&x=c6X787P3(4?u3@_}RPRD*_5z1M&IkE@QL*qgy)Sq>UO!6y~WB3$rOUBQp zd$^@gKjic~28c<6no$7yD;csk=*m^R^DIH9Xb~@o{!3-?Fz6W6!Pf=_*5k>sck7^P z6@a-ss1!`lQi>c5y~QHWg97==*~g~XS9~@aRSI3<@>!CETy4!^f>Zsz?e<~t({7|b znxddZos1;B$|rP&XUfGGV_jOkU|O!(yoQ&>vQVEdi>RA3T{w6{O{LFghCp(){5UuC z%TpCU`slG)SMNyEd$WkOKAql?~Ue)Kl;?`#U*0?b4 zIh;P^4JR)*AnB93hE`IIi+t^xxaBHKtnEfVlJRDe%iEvW!uSnoFf*bKdp3>~%c}US zJ_bYF`6NzG9hFuJ5@U%TjJ4d?%+i$u{ZXgS)QdQ!A;jB*$x)L!BupJFb%ci5*6W1MfZ zkczWvNSCwlCEkJB9!@pI$0Em532ri+0BzNdUpv<#XDe>jVO_$jvNGWfK8NXEBj%JV z`}*8mZpuoBV#UH1(YEC`X4gnjT15X%g63E`tIOZVA|Xz(Y?nrbQ>{y(K{oE&9y4&Dp>H z%C)CveY}~nI^>2~c^?YLy2>Mvn{Rv;dHF33eg-uc*c}&1B%2k*(cVo9w;HGfnhu*x z)CezHe%5F~jdDhiFaijiU#?HYb`oTHRisV8RSXPmLRV=8 z!Moy}>kAg$Guv3dXbAzzV$0G3e%J__F@_(G^YUAa{4o3WhdL;xD<~nO%D!-h+v&FS-PoM)wFm^R*DS>*h$g8^ 
z?fIa#ft+86%1YSU!Z-N5}kVQA2M@*<%mGF^GYlN-F^gdoyw5RfZn z>)m_y&f^bAGg&uos_Z2sp(5)-5qKlbQU&zM8d;@&GN{Z+*#&w~TdY5Uhz#kzB($}4 zdl8`bkq@ocHs`AKl#^7ec5F4pm&#kA^Y@QLu9Rw$SgIq zdvq%xK5j|u*k5Xcwzmm$%yIxLC8xA z9F3mK!$C7o1j%(oW?r5PSzSCz_O)zp-7o5W-_`rgS?<+-lvM4^aRkxIl3A(2s=g|} zwo*`dAv;O8&Gxl@G^U}fQk0DkE6(g&>pBCmgCG7Z3udQ8;A8QRQ3Nowx|0m1#wq9b z^4hi)Z`Q-K%9ZQxdy2u@BhBWGHn$Q4o_bGi#>Ll4ugTZb*M-rZSPM{?hh);?q> zj5O=C8_9?6n2ml}*Gs{h=)3sj;I!1W8GLAT5dY6jR6bR63qm4t=N*Z=vcf+6CTlUd zW|huqv4~nsmUP8Ma{Ncptpwcka`+{DkSsKG6AfX8RKX$)cc161-yEgS9YYZA@!!aZ zV;FZ@Aot$+_J!(1B9x4=*znAZAfAp$sx&zjZ55)Ab~o7%gL;k_S-VEYFGQM;F9c3b zM$%10E+0|QcM_jpKjzKHqjp|KNvwLpmX{iE*7(;jYpz{|g(iKW*5Gt%O?cX|RF69T z^-cs$&l}bbuUR)l?tU0o={w|cZCcGK@3`{}yW|8EQF2XzULTqx6^hg8gitJgSj9a> z0#H*N1%dJQK(EbdfNBm6tef;XZ6Is&3q)D>>w z(N}qzh;ro*j-Tm*gX@(MgUDx5-tIU#A!(EN_?fvzTlR7Or&kY54U{T?|4{d4lTItt zr$rOC&pGQ$SaviGKlUvsBl*q5?^iYTDx6{w> zVz#I{Y2TvRela_9fp<+&YfKkhrDZRlsm>$}ZC*qz>|^yb8Ewsw{b{L&J^~fehbukW zRUPY?o4v&zDGJ(EeS0Xd06nskQcEG9i*MNJY1ENKuvJV)OC3Px;hYj3I)isnDi5^J z?OI#@J`&%*b`o|uIC~{6ee#3XxDPiNSI6ovUu+1UK6jO?F_(JP(Rw%Oo;$JO>Owcs zC;&d0MlG$g4Kfp5{i6gT%+9|pv-WdVOZ~Pdep~1iHp3c)Abv?e`Y)d1LKn`?$M@ZQ zN%2aSTKsYw0UqH;hyvYXs`1ae@VU&lN7~PW=>$gYYemlu?@a8S1|n;z zi3h29vlJZeOx+gluv{HN!N>40F=WmqX1}zNE0{{3_($zKAT-h*2#h-%MXG@@fP+Zt zfjd>_C#O7X)tvB_V33-3TP! 
zCpCv%V5!3bW)S~zCCNK5^N`?ru1!UE<5IVd>yll&NP2jBs9mQIgwQ37w#3I#%51VF zvfsy)($ZDbLO@PJj0>WH{yeS058%je4c|Y9- zHD8h;r0GD&)p?75(#DvZ`k7y&V7a}~v|Eh7%!yDd=SY6A&j_*WkL2|Yg`IQ-C`Iu?OqLEc;HjW%=)~Ok2-hNCHQ5a!c>XQ|P zt<4TJrz%3V%Cir(t$kY!r{}reN@vBTo_oD-m@d~>dz+iP!$ytYAjU=^v34n6-|606 zDtj7&d9n2|VpJ7523xwHH*zCX<=$E_oiNgK!~#*?dj|m-C(5AL7?Q?!c?r{$e-6FF4vH-w`oB*6T0A!>C89g+-*vkOZ-gV&U9(EH_@i`xe#m3djXk;9u51OcO}-`L|!L7O60W73;ET@v(011Fmc={39ES)MK~BC0quM;8eg zv5U|0T;!l1Q-P-!T`9&IlXb9|x!H;ZR<$@&UQ5&052umqJ>sa4`_FwMz3OvLQQmXs z{_NH8d5=Ve`aRdsenA+PK84NnskdYrzA8Y5(%tF$YU<%p$7}0mB9T4#P$IFt-oncr z7#&%chWxr?S4F_7&E`q?vWZ>72^O@VBs?-(0x>uj&#+9F>7FGzc?wW`T!&?6ev-AwxEo+10=Y;!JCrAu+h|Efs+q;gb;6nN z$K^EFonD%bM#j5f&{s#ubL3JYtH?1q~8lin8Kz$aQ;uL~KWkdSm}RAmd!4E+Xk zUeMlRE^`1#e^P8tyU?jOpqV4Yp%Q8D?L9TEq87ACixl)o#fU!D@|u#vFULf%iSn=jH6JXN5Q> z#^E$laB{|21+Xc+EIK=IUOoSbawKV3Bha+wDGE`&_bW_!%QK-$gErkz?Ch|M*MJhi zK}+CV*<#$OHvvZ?Z@dC*;{Q#)aA7`4i@lX>47Sw-T4sHd*(nl?)XLMu?3#lKR_|u< z&FsZnA^J%kpJvb0Ryf%r%7?Ct;wa|3nrX<*FCtxvmv(_l%zST!i(&Qp;H=;8Jgs{o z*N-tlukWR`WPE*5V>n7}`+F*X(l7kDdpYx8teAK*oEBO#yLFqFB}>(4 zxABKhmh9Zl_VJPYF%9m6j|0GPe5Gme;rqp8K#kE|M$4u*I!!3&4ikA=>Oo%?13h^b zpkQ8S%9G6}jRI7D-3+fBkV5ROB(T=)9O)Mv)l9UZ97G*wpD-bicy=GTo7MAr{xg#H zwn0vGFXPG;!N6OSX_Abuzx{yhEh}k5sPgtBPzzh=gzEMEo{K-=tSPaTIH<+PS_(Ft zsz^GvX)nfqsa*+WPOgEVI(unbnoyr*zXUXE>#e`zRy^k|x8xmZ1V2(Xq|l0IRQ6nr z(MjhQYD28Xk&m|DADYs(8=Bf|imGbCS&^RkaHtbV&7;D6>5fGdvq?>XwT>7%L?a+N zu3;rM?gUXSw4d5n$lY+YB#q>K@IQEmBGPRWk%IGf39T7z5u?1NA2W8uWUU;_94s4ah(Y8h%sRum{LXY zw-z}jV@(81qnj@R*MCb?j$vBdAI`nibwTu%?4fkoyBTw-*}X+as_3t8+s@=fT9CiP zMC?Vqb2Hj3ddo;|YJB&_Syw;R3^p#$SyC zN=!Pb7bp!f7_qQ|b>n@*-BREFB}+Ar4>2j2iQXKXDL!2e?m=ur`Swc|uFH&?JN+!h z@1Mw!F>)kZ)a79APT3^b1ml*1Vc17Cp4{1JVH}1~5~XBCcEHGN?$WPxdhVUzvIotX zl`O1Xf+n4lPId2)h+6b#^n2uVRo&38V>=*7Qa}2Qs@|&$7)=5TiL7C&y~gx^Nky4&!y(}PI*vG zhN~sv#HZ?s4m~0?$~m`mb=994IU>L^w4>=C0=6zHG~MVw%L0cU{cY}iGF<%A6e4d9Lps&0S&=V+-C5f{plg8ArFx{;g%RqZHl-_(vmrO=+5uL?`N5R(9(iSS&z+(Z$?>t?7p~;HLL9 
zJ}NZ_zPF}5muN(pi@O74T5XcG;Y+d*=WAO1&Xb*;?epjK`|3-MTvjom%Lva!Ez5%R z3Hx;5{7TTu?@$u`?CbJ^CY_YdJYxYwo!gVR z4t4Q$m5h;(Jg7!wHN?`h>?neT$KXr3x87qxQUvnYuv{-JYlDiN2oOH8R||W4Uioa5 zu_9;(n{nXt-+|HHY4%CR_0{Fy+#{Z72Q~TjZIz_Tnb$liQ=PhZjw#M?>iwLN9za9upv#pFIcJJ-FZ0Ch-4NY8qjpJ;MEXO>y9O!i>9Y%MW z_T2*~{>Nx+#Tladpd14ABCa4L zFja5*00zBU&m{SA_jB$t;3*o3jMZdxJXJ3c%#*3>(nxb^ipaGxyA7F9E!Kkx`K)8| ziwUC;rLdPb-y6P)`NJ(!ronH1+av81 zBc>+OwDp7}{By_(!WkuAU-eocy;Ny|wshnfz2uHU3Sr=yP_04#vw(8T%5+tgRqa6Y zXU9O5jFRgQw;rA0m>NxZ`ke;WqvJu}ekOa5j?EoT!>>Q&bPtyM=@S^=XZ|_#ND?y# zMu3Qls5bg+fratT)xvAdxAl#Y-(ITGOA?OoK~1)vj;?LKQE02?eLnZc|IJijd(ga5 zH~yH%aJ6@BnEGmCj4LqryQj=jGGSumhX1gIxp4$}D?gRBM?;{yz+>>XeG1;rF~t_e z*d44Rj=@1s?Ql9S5m6CIcgpKaQTnL_>_Y!^oCeG`fwf>;=;T{|;asAGL_fUdDH$=*j|EZ~DP~=S@5OPy zAnY(7A=@#;0=fCnN8+uiNQWXZT!2r~bWcKJZAUpAUYB+9Rv8H++<=X$O0tT|ye1~y z|29N(Z7N$R+w!W|Nq!vM(gkviwwg>i$JnA{8_J_0%53k^!TP8+gh$yO!`z_SZ{dgiy}QI71At~|?Pe)L>_ zzU|SC{ohIrkDb)ACq#Q!Iv@A`^3sDdPRAM*>ZIGO_a>80q1yuOfw(BBVEI>=dPLRA zNZXqSRQ~l$V&aC7W$WRVS>kt*H22V+ChKuOL#w4~&RiH(F7xW2PqX`6K4228@9!#| z#VPgtmJqr67NgBD#mT|6L{a^sj6;B4>;l-vt|*0Mz#*0yga zIQkp>;7jb&&6YAFKNW)K`#VDy)^l3UD>=!As+V}@^t_r|WRKYbVcCiQ+uVZsU{}Sy zQWzQ+Q8J>oAiI5trQ!BD4KwAs`b&!p&Wa~PUf_MCd-`*}m_sz@x@#v>^#J)_Ek_Qu zHWr<)bg06>`9iOdkYvNJ)hR=Q13&r0p8vvsiP7eanZwE=TCZFs-3dp8FBwBL_aAy% zucj{2|Evh;YqOKD%D^v&lknYaSz>JR3$S*=deGN%9A~e+6!++`N5MrLTfB-Ne6d{m+Wfk#P z4KHpn#!XgHO{nK2pj`yT#q-TxGg3kl``Wea#8L!3=80V9uc3hP8MZ-m&~gl57OVn^ z-2IeogQ0JSa+-&A<}#LmCUHWKz@txm+<3_o;62gOqa}@%lyctMBRb7facbr?jy%c8 z?fZa1WlPq?1w)kX$6DM#K|Dvg3h_fbDLWt_Ew}IWo-TrVRT=2yZgo1UY&;+(phbbB zXDj;%cM*C$f!%Mqb%%^t?s9=>gi`f?>WE2RXEU+;`I)%Pcb3X@_NziM*X;Ur3!wv& z*6d#nvh(~Lms+p~Uh>?McpaU04|Hp1Rh^IWL8Xd)Xw~Ohi}@yiK5o}K8*;c=2e$U^ z?I^=%?1Fm!?W(=9#05?KH6U8GrR3(M*V<1K9o`9_<&mkg{c+QJD+xGSqfWz#oMSwc zO98M=t^MD&sb^EbepCuIEZQNMn2SsU>OfI8mT699e65)f`s28sxjB!=<7N4Geu%7>V32yu=^CTFBIYv>Ngir zm3Q+9&ZQrcZK+>3MtPt35xo5CF&qdt;ih<5PMa#EWJV@(VJqB!O+FHl29!o_sKtTo zj|cscW49y0Eho%Kw8~&Nu%~l$zhDAjwe>f#iMFq8jmH;&Vv1C}Y1geitPgEtAO_1_ 
zk+cOxNaNHCnR-97t# zZa+L5?}bTn%tcq(TXz-$d9w>`)Un#Q=VodRv)s934lzFRJptg?ITJ&cN5srlHG zs(o8~jgkDPDH3IN{o%E0Zpkef?@v<}?ELKtOB;*pd~m?~nE!tv5@~*1=_OQ7Uv{)s z?_0IR8oMgtb8gEsnAsBWL9;Xah~sgd-;I8}33UPk$ZfaovTIv5a^|~ED9NR$mh%Gq z&pO!pjM*sL4t4>McLbqfP-T^@pTSLWcJkw2vKJQ3L}xOF9MG#L;f!YCyPT=)ZsCh(oM4$?_rC1-87M8;XgJg~tQ< zCZ6vwWx+uH_{@O#e#=S2%`J7qU|PLY+t5rgAVtgTw``&JI~LcjZ!ox@^aW7e%I5UF zu07;nBOYg|9G;wk8>>!}QWyvuCI6vBozq?DdH7pMZG~yTZqc=H1eS5b(D!6n;zs9} zXjlo0!`sH~DkNcp>{qu(Givj%YQt1H0*en!&sqQMl;=?v7TsVYRnWo?aZaY z3ad4CgWbaazB%CX{~_(uEu(>7jUHGAoP;`gTmVzX^A#YKRjxc+VC`9<>BQb!%*r9w zbd;xx>%>G;7cC+|T-^*R2Q`ZwlBDl^rl@d)HYD<&fKlfEWr2P<1~_xj5Jr79Ockl0 zD^G0LFFvlH7%vRBA1vjuI*wdg6i z1bfc1y&Mgg#h+8E((0QO)+&?e@D!2LE+0NzPrB&)R_Mn=Yl3{~mPhKh6TyhnVYlyO zS*@)NUBH;IUr73E$E;+}ToRik-K}>{nCYxX^}JEG#YYre97A@&Fd~f8Wsa z5C2gNdhVR-m6p+Q-6r@Up+McO66E%*1Y~J@`5i4dHOpkINFCXqmmp{1U52DkvB#g? z149r`W=t${j3ra*Ymwh3!W!C7yJ084U*^W^fVXB29VVu9hO+Z7W)fosUR(q>NC^4P zBXWKC`tow-qWdie%WdR_fqlkO-}m4WJqjk`LnM|l)rkq&J8l>}8M9J-9*m!U+{-1K}GHPW$R(u8wb8KSTapn@kj zmEh|X>ZV$PL!eo5gwr+0QuxxUKt_Eq-XR#9%VKzyJVmguUd@dd3YUxv^=Ph_Lwzr4F*!~ zf4L1D%;CY#eo2_jj|X??{d>Cg0(hw(`+SmVm{HoVFbrzs5H6^-8ye^{fp0qDLtoth ztd5)^ytwyttt^^)v|fXCee zIGXBHHgHSSQ^lgD9I|~o9k+ekS`j}NcYak4w=HK|S7ozyO9wdCe-2$VPrq~VCAQ!Q z?bW4GV(M~S)Cp<$*f)EK#=08j?ae5k#<|;ceE)UP2piM=aiNs8$d1TJv#a}y#&`|V z_av*r+si1?PWT*Ld zsK~+ro)s1z|L_Dv7nAX&@x=Y7cErW^Lu`bqj|e?2D9oUL_J2rJg8&3|7Q6Ll20y-WQ|AFlR5k#+`r`rq=TBX=$y*zg~WWq%gp zQ_#RYI)h)+DXl}0k+2l}sJ^lBaIxQX+*-LddwM$okw27xqx%$WUsz$y%j$vFPyT4CAkx&S96SC)Vrxk=lHcVCjl~GTZ)(ESK_Ko32_vB#Irg$X)F82@O%)x zarxxv7f-UxdBrfptVO51%WZwz>ZtjumUHG_Ao&l3SVhMtSg>UjxFX6uR{14u(0YWS zFS}&jJ`}qE<3qLU7SgvG80;w4NU*MAbyWmv)QLks1Kq;iSiYM<+M+kNAS4o8;UvR=6 z#0&+77@!~&c(Ffpjpk1Wzfg#Z_0ihpzd~4AsCbd;TA!0=fGf*F=W|lVUe)d(il(4R zx(drROBgatK@id_H57W=-*=~;>T5R#m|MtG!x@_56Lwpe$)R5-X?^~doT22B2^iR#{*__)a)YXoZTq-dD*O-~keAbZ)?=S$K|BW)`EazJ1uknycjfY~dKPmNi zF0v+g9GcKH)Q_BFk~R~G6>Htqh3V}pJc>kYT=c~+z8vks%Y}g!;9-V&N$%XbScIr< z`dE!k$v~6N;g_%pMs0=~TW{`J@0Jkhz{*3#3y9t>@#O@J^^Q_t 
zd2KSm<|9fyQ)+!BoeEvB_!Ta1uJ^Y{v!siW=t0$C>TGO9uKTBs@HX2Bn~*JmzFJhB zPMw#kkz(VDkkhZZ{~a3}B!V>|rh+!{FXw4o_^MuQe%NVf#Eux`@?uttaaQpO!yygh z7;I3hBJ^IAq`{NlfhR$Q4Y)udk6&*5Zl09%`iwX+MIO3F-)$3ckbR*d_{)#|z_m;W z0RI^`#Ou4HxUNkd#pl81+MyA(oLXuoOU)}ED7Uc&A1#nOUw`LD?mljbcUaU1SFgj`xJ+I~q{`PK>L)yFOus|-H-S3XFN_4@YecC}0O z53a697du_ii){25UHMCy+x>)lxBK03-637Bql63=3wHri`;^QuV=w+n-uvzg#65IE zZAWiapyFx%>`}6(2_q%u9 z|5yvw%;KE$#NN-|=ftL&Z(vAa+K(!5uoqSg^v!lq{FDOai~i7H*zyHCssb&cJ=%dF znKteJ7olVJl%^>8*xc3#FWWRJfG=uI|10+abAA#k-7z zbJlGQ2W8RV)AhFH#otbl@bBX9{KK=Txsb8H#n0R#eH9OSc^Md8wMztw)LC6MxLW3b|Rz=e|jPY+of21+>%4d;ZJq@<}+vp zQO>NX}pO zOVWhh0{*GvRp1#yY`6}^3+r8DSrb!-ur%&JQt<3riM5i)b?%LD9z!8CC*&pvHpVV z&NCu&c8}hc89g>$W^u)Mc}({)g&fPq>+DOrtouDy6ZJH-%5 z59VO%b?YC5Ie31`T;`6SRGL~O5toHfwOcf9Z_$G7oguCVbPcA_kD|(7bRD|)12f_d{Yy7Hlf?72rlm+!RPIiA z{ic*?DKD;cvzJ4O__s^bt~HMdeTKx>bU(YvIh9+aNo)vr42xgBz%vn>m^iU-1xqIs zl;bGDeS?*_kextyUXU**PDNpcpai-^#BKrGyI9uZWI9dULXwmIe7(fa;KCo4xdpIg zWrK_Gs(|qN;yOCW?sKR|xPRFlDwJKc_m7@SUWq+ODavqf!dJDBJP<4j_QXH^NufIJ)?Qz z@tvVDwi6&czW7~uPK@X8;X~h5UB)2{N&UiEoD4pmp~zo(#2b2047&-$&AWE7gE1O@ zd2~r$AMXr(5e(x*2!3d!MmC~X+T>{TIkla{)2&<>(&@+lcEXs}AEOS5T=?cH67b?T#YMq(jb^h3(Bh;wv8$D{x zxW@BBqSi*tdQk<>k|r`)*BAy(zHRjg(ghK0=MYyVN}RNQCf4uMf> zbEP6ck9fjjh9L5f^bKpGnh3VH-redF7FnqvHr?e3q>n;F;wY<3N_;dBSZ1tG(76Ei z{<~JNYoaC1FfTvNPkKN1c4_r@tZxjB=EbWUrJnN>yV7a7X2qr4;G3MHDS1>|B6rAE z8dURKF&!LUwqMgi(<{xc0_Vv@7C^1r2GvXlvt*7(D36oJJU4fY2_9l^(F3aUZWj56 z@!xpgH|Q&?d}y`GEnE=B9B&X)d*Mk6#p?*`bMhjTZ)j=(7fK9S=5DsOVR}gnHR0mD zeOV8UW7~=z(b&C?yaf_AB~DlY(y+LVaHp)&mLJ2cn5={g`+~4sF8Jp?d5xAPH@XFy zYn4{pFFlaZ*aCd7?FH>iC@Gh=^~9t=*jg09Z5=msPqg1{9l3C3S;7oq!7kZC_kDfYcN8B^tJt3v+B2REV_KZX>)(9dJ)h(S9>Yv#B&bXW(B2~+kwE6dTtj0@M3 zx;d)fci2OdXFRDf`g6=f3tqW`;ifO$MYb7Dbf$6+x$7~fU!jB+c6TQ+kP4@bvIII# zQt>R^dvA!e1o*@xEv1~HZnp&D-Q~7aw4mQx$ti><+iHhqlyGesU z;)@d`-l2r1Ct8u!BIEHZ^zCv(b51zyJKRB^F_ejfdVc*yPner4To)PXQCXEBpjJj2 zK2GitjMc)@=~?O{Z@rKCa3&nw=S_Y{`d6gr7as#L^|wFr+Y?iDfoIJ;q5aDB58k!e4B*mVqAMm$jPP#|!F3ij=F*uu!(91f5G(%pC1{daGBWQVf 
zX6rsSW|O|rc_NuKa@(f2e3{dIu7}0D5WKI++>DwfJwmbiFxBe3qft2AfJLS}9`{)f zB#fr2y>pNeZ0h#|&S=)o$e{ewQ1vQqQZuN=(T5NaBAc?0p<3$Qk4ba%5|#&1#1 zXBi5%?Mi1)U#QacoQS^f{P;-HAa%84w$M0^*Y#j`<3c}7-+`3E|9(QiDL7A$f&W+J z)b9$`|xO{i7!;*iHFE8@*ft5jR{6uwDiG z*2uIaku=ly�sn(_(N3K4mfwsSlZ`ljqRYRX>Z-KIY-cOGX4%cgpI@$gC}Ri0i7T zLSShD+U+W!j)R^HR0vDycJb7X$~pX<4Eqe%c_4p-$d4Vf^^`7py{8tBGrT?F#p-w4 z;It_BkBCyj%%&F_&dbQK^b`6wjVP5f1D>m*S7(kZSH{J9#uwYV6!WJAU4n`iUF$oJW^JsPbX%fRC#FWMa=YL3jo|{R{C93o9_G*AG!T}`_Kq^OtJfw0 zI-Hg?pH?>pW1SxF-I!1##$9A_7?2|_>>SmewxMcGVClU{dcSlN8r~wJV<==u^aVwe z2aU{sGB{q-<4xLoZf4}7CDSabvj)~5-0=jdNLUo4yO6f-5kC?orA*V;1IUE*mGlK~ zhuezJco~FEXe>%h$8O&IMM|I=kE^*5r1)~^6!U6-_awW9o1$;cJ!JLtBhX5JQc{7H zw^W=OJ|J%)@=_3uR@s^?9)1!x!&WpuB$$xned))rBs)7|PxIO5UiCrMO#bC{^i;%a zN*p;>pzCd2$4S_Kk(;}B8J>i=Zt9d&+FCy5S6Y_RJQ6LRf#TUWy%Lf0esbio7(t1e zYO*ft9gTrYoX6N4M)tjMqr-kZ;|DzBz-+tsS!#6wjc@&grr=H#I8T!5r*fm|yi^6Y zALc%gwA16z=6wcz-RjL7jw%UP;SOe0#*70?Le*RibWGUY1~vMMG%ec8=nc$`9Y>u< zRbPi~QmGQJOG|1vzr%1SKgv@z$4^8Ewf5$7#zWJ2-ZCN2hg!JJFc!XKwX5S_l9Jq8 z<=Hn)mBoZdfB_w_bMtiI^lbcu2^kAq&qW$nWG%lH%gt>vP6Y*C?uMi!si(_gxOM%7Q@$2ve&VJjO)P`MP%&JS zC`ZDXvw!e&QGDkwa1>VO$0eg!Z(}*a6rXH5-dAWFh69r0;o3r{J)_OjurC2D9yvGh z`aZdJ3|I_r2*oK1!{4gYl34(<0K0 ze+d+k{Sl`FRWx=k0{H`en5jZx}!g!Ok)St%J#wKtjwz;(oF0%R4}t zmjJc)4ZKbI;l@heaE6h9Pk^Z=YQ6^k>J1a3+*pn-8wnjj9#nELa4 zPtb?QYW=nCo^Y@Ep;DTjr}*jwAIZEa&pi)Scn zi0w32CxrbI>SFIALzB|>X^(iaVX?*p5pDbjakh9+~b)N2y4tH*}@jRbU>G)y+Q19}sfGVVNN zZNwKEm2`4T?-LrO>L_r`X8KBbpl%-#HLnc+jC=H_7M2#S2k9Pe#k>vWpWw@iEj198 zo8mTRX}r@FK77&0X@F_LC%2m+)H}gQP0R2(ypm#<|KrziE1Gx0;Z^ZrKeTpoETL-V zSJ?g6;K@@w)~yO%dTwO+6Ii}&GuwyM6*skJ%D@;T^mS8ce^x}wd(DK$aoe)JY6sm# zjyr8JDy*{;qdesi2cMWhFXC^^)V)wX0kqYM=0qs?ilu5Ck1FdKvkw}7iM46Nv+uQ) zHOZR9>k`g-(^7p_PcIZ3U(ytw*;@|n%of`WjVwRIjg@yu#Wi}Y(A9HA9rl~S7em`G z7+EM;Z%*jDK!LOw(^!O=FQ!QQq?EX`_=H^tpG#|=c}alGV*<~8@?ve}K?wm>)4Gal zmF`I|{bQ<+cC@~*7n%K9F@wgvF3%i@1iJK1slz`XAht*`mil_)s6&ATqz;B&6odil z`^T3w4abf8L*n(yb4{!Mk+W4Mb$>~6cA_X@AcCEZmE1o^%z|T9xozCPnFU9t3gqrz 
zipof;V=NgO<|!-o3HO;vJrOfZt7op)mv=?a{fXuIv66r-PSNTx5Lj4Nj!fIh-*PQF zb?_8iij!z+>VRuRzpwPVca9G4;43vv!(f9)+>|F)4Ww&e3h-%l)uT|jjM287bzc54 zDm-}|TEwow<0b(vmhp|a^xj*YS(P>4^Dj3c{)7<;t!6kzO0fN=Nh>O840K4|bcuun zU*EI5nSB!X`ex$niEc$>W(9R>IlywyF$7JKp+I)Y2+ADUHc@F;SIicjup9nEYxq(n z1Z4kD&yT`{+z=45du{;x+7T95tU<2<5_%-yC2ndRz#*_dDQy66G7B=no0g@#H>@1= zsnM}CTx={&Ak^P38p4muMg_+9tBwxs7LZnP7(?Ys3?J;)h6JRG>zD(anKk-QOqEIrkiQ+i_e7g}nDvuM=AzVD_&gUX}6pYAzw4s}^F_vhpwJ}+a z*_wvuT1XnC3ikHF=}BN?w4WQdZ2beH0CGv+3n!zA@!Rq;m2AOLp3Oc@Ihe?W;H=?h z^Y{Ws{aqy;R0zKS`*7~|TiV&$oZ2;&v%7jGHC@G%%4|pR3}rXn$!8vDLr8W!;izFR z5BJp@24UN>V7Wu%JqucSDds;?1iUlu&!a|-c-`eueIs*feUvMG!s}bROpABGWaJYw zmMT0h7~T6xE^)B$SXfniB!Tn^Yvv*(3$ig2e6@k`HY_P6CSlTgt(O_XFuT@>-9FH$ z3FRjY_OELuOH4Ylwp!tjy45T=8YIl!%-879xsw$MgCYqLr z#$=M7lxQcGM8w?8bnO94=25@^8>WrZ3+P z3inAI^hhJ*u{3xD4*3{HqMIg8rmtmuVpnGjAz$)*{$w(hLidH6!F>>(R?FPcM7AjQ z` zFX0+w?$ag6s>_W!p}0VoKY&x|(jnKv0at`jcz2hxCvX%Crg?98LRWE3_dGD)wAkhu zrwF$v$6sPPJ5;IpSDHH7)B}Nb(`L2`Zkc?e%XHN=k91I?{Sp3K)Pe4K>o-vY%{^9L zRE^2#N6%@?!{Uez1KT!GcBel@Fm3SiU1is2-FuoY*hRI$Dq5m)f1{IW=#$NLd&`1+ z*}J6p?Xf0iL?p=*@Z!mmQC!${cFBdsXGhR91pE-X?v9cfVaJ>uYfOKhJO8#7Bx#i; z@f`Q(+wRYl^-gq9kyr>FbNr@jpnBbYDW}?f7&SO#d2?fHdR%qPrOu{ewo!X`>*y!E zn=WKA8cbHMaw5Huy)%@{PYzFPykJcDbZxORN=ybhShgd62@F)|mp4X_;jahSco7ebjooq*wI-a_TAWqKohIKWj+JRwk|K@c~82 zP4j!Id|2?+9d<*xdNyv_8%mS5r!NtfjAmeaaBFHkYS#CQz3s0sthD3P&}$417A_Fs z+Ha#sH3E#bb;YAw4PD*<3?=1K>YeaO75_BifKL@j4D3~VKRe&IM*n8PSfVPfc5mNU z;&9*qE}>^FzYiN?c_f==@|0Sgz!3sj@G{;LJfI^+9CQM#kzij1u~)CAcBs!uTXs}k z+~B9NRbjK`rb=`5Sl#_pFq^H^?Yjm1!yKHdkjb>xJ)LmvDek1OY%k48L6pLPh`Cz! 
zj?SjUckeNvahbrSx11qumAQpySix|yMVqI6oCBssXC+E)m0na}AaG9W?%xc4iG+4k z`UNv-au1cl!2L7!7)xeVM=1#uPr_-l?0K{=3IxW>%?@yAsfbN0ny*S*Qzd5@1PPzk zw%4D-r4(%i`^=@DtQhJ>43bi=n(#9I3MeP`l;?FSdWJ=3Ra*yG~ox({WY}P99^#{$C)0O&Z=c|W*})s^fdcrvmqi;{Id>OAlV6v;f;-~{4Sx+ddn{%GS|7AxiyPH=foN!= z9{7RL-#3qCwzD0?OETt={DR6tzhVFf)ab%1l3*Ek>~4uE&j3*-JK=wELNO<Wm@ z$TP-_LTHWFI|di!I9i(DXIhZ0ABp1xACh(phxXCYmM0Ai!j#GeQB(;$IdZKiLbX2c z#Bm&#)VMZ>+^&oDE)Ihh>J=h4oJT*|Ct@975U8`Q?QOG#%ecquX1Kid3aC9MS;u@r z=ayq%+aTC!NXrm6k6&*PLrW1A?E1EZk(gWM*b7>St|;z#cSz~!?FjtWdDSi^=-dO{~!;+*zFFk|RUratGA92Oi=LO}yH7sWg3AT**jAyFlOR<9Q zd-$C|{B=C9*3HN3&72mvH|41ognICSPuI&F_)++J%WNzY?``IiS<5CVEN&n|TQppu zyzjQg+Toz}Muf0ByRlN^`q18ymZ=GBaaECLzmN1dslwBE5qcGzOeXAzwN&yWq+wbq z&ET~}s&+{JFBGMtyaF-gl^V|Qs*c4CQH}lgbnhP=dFKb73)NSA`piE1M!VNag(h>G z*i`RaZ!dU>lR;ofO1*#Z)R_|357|5Aa*@yMpb?31XWQ&J-NDnxG3aXUGns`5iTAzr zEq7!~+U|`vG;R4 zBHF8It{DTv+$?Xc+J@aqP-4~jxwW(jkv2v}OK92`&$RT(&Cm2ohO-v_&ddx`XC(K2 z%BcwTgSMPk@Eguj`(DxqRToG&AC2pGan64g7l@;zY-rs2Z8g;PDz5M-p!fMn1er|P zonCq0Xf-k{|Cc{eWy$a(2}im(`;XV(Oh-{XLD||uB+7;SVv-+B!Ib5L=x|&M_9*h4 z*xw0b#_Jwvb`7L>ctjX{9lR4$891?9UDBIlVX)+`nEjFSRr9Z$6)&aF%lV>?YYSmDl!|pu|5`L7u>X16SOA5XVoTZF;@`MqAs^EICrZx{;*1; zQZBSQ4adF-HSMK9W&^VIkJFF`UFtMFFY(nma(yAf-|dAE+WINAYIF9oB4j_cE6CcKxq35m)RA(}qo*@UKZe+U3|E$PEPGNAHlM&J=n>mok>&AFEkj z!xoXeNj82HyGht0|1_54kY94>rGZdQ%H6wx8Jz1)z(x%_uRu-K7DaklWb$F(&$F0AQCH=x z-<@j~!8+&$?~81-zfDdiR`gh!3PK2_V7A0ztJoqJ?;)_30&yNh*g~`O4r>+3VLJyi zGOLg5Cx%~3rqH-Qj*K%)67G+JnVZmw+!y9MVc*4(KQL&meq>{qU6S-P#xhsr_8?~S zO0a+bB?K&e=5u>WyRf*}g7>nA_Wzw>IYV|C|4hYJU+%_>Ud7_&=%P2(aP|DI@Z;d$ zG?pgl!0NNp^Fiiq>FU=M#=8DhvNorzD<1vPyWbwta1&3N5k5&`k!Pdyokz$*U7A_l zzHfFpsBpJR_q7fG%5)u~GY_)j5jiL{%vkvAZP+1xBUcGGCp_&`N-2UM8?MkVh~i+# z{wFJ*QFfkV@G_2Y$hNrvkGj=AW0NNp@e}WJCZ36t)>_MT*k*c$6gRSb40}>rXYRu$ zG;W(F5kbIK#kz^A*NMI8Z#HC=+c`qf&;I&`hItLT-+6)<~X%TOy%$dpmfdVMS~CvIm*3qO*%u|t*$+e#z$X%Q~JmxIA7KbDl0 z;ibSCW~{Sx5^#RHHzVf$tXi}pV<#E(PS{~@518zAh|N0|mT_3Ww3um@pO~#nDUQuL z-Rk7<13Q@U6_BqOd3R03}704e*J>Ec4Nr^$GFH2&`AIjQkE@&d{xbK_Cg78~&! 
zZ)6Nc6Z7uCWTpJPhpD1kGli{j;!viZi=Q(dA0v$LNJ4YOpFbMKo`o8~DqCXa(p;PM zgz9eVe0cKDC0~mVetX_(Jxo1(C4>pk;70u;=Upqs2l zv$ucrn2_JjIkdAIwScAVhWxuidjtdBzNfl}ZR|#dIN8vcQJaWJd>e49uCFH2HsK)) za+27IQ**-_)}ji>Tnc>)a-S%XSHyB`hR?6QT9NFcj&oCRvC0kj!=1Qf>1Jwv#aTFO z+)ax154@6DLMIxdNRkJVo52h~OxaR~3~=`C6@vBg8Cd^T>GD|6$?oY%0>YwE@O+rG zeyQ@BH`~Qv&4+dh1Np&kzR%wvtI$QSpkdCgh^sG471|tqyLBSWvw|(!#e`1ZchR^0 zY5X#l;DlBBYv(sL9tPhsP*t;VVy(KwQQ#8npX8j6-1=_K35|?XQ!XD% zEyzgZ8fo4Ntm|MUl&ORPcwhBnkj^VfTVkp+w+4YYU3W6YRJgTyI*$FIN2*R9 zm1;s76}W=m%>B1npW=g>h<{=ulOk(qP!L5ub}6UwmexXJ)`MdIK>&`{EFK}{(28F9 z4>4=8V4H^NmPLQ0y?y&~?ii7BA0Xh5ImSu1a>tC@O(4e~3;02Zxd;A~yZPgZCE! zdN}ZZ>0ul;%CF+wbl%+qNye7QA`!o4ed09zlz~~`dKk6;x0}^~s?(Fqc+_F>b>Ah5 zq^`%Q7BE91fByLdhHbMw*_>N0aG;w9Q=u=IKKYwn?Qxb2hI;r-(lZTx#pPDdhc zCS<<>D;k`zUi@Cm+(w;nmX=S=Uo(ar9Uyto2^Z-U!y|IxTpukD$QPS7Wg%VNNkrVC zYPDs*_3K!t&}SH^4ANxdVatba%2ggICyTNilEs|-hI1tjzfz#9-Vs}mFCk4&Ryy$H z-1B=|*z^uCW3eZW(Y6MsR5w`#J>JjM@pArYPv|7+-(w?iRyDo&gx4J768a22-nY(& z)+j2do9z6oMrNnl>dpbEA>H5HpIJ*A?_#o-UO!vcc*oPdtQDtO9#jK;1QXczcQL-1 zoeTWi4Mt;c1a~1G)o=SIYJi412Sd|fg*F+~-0v$Rpe_4r&pS;S}lM5}=ox?{ni zTTi@w%~xn=hqb|LOxBe9JcDnHrezJvUpWVm_v~gNWqVra?8ihR6l)M>qMW|gE3*&v z4d&|^m7VNh5x$x?4cK-f>$DxS=ls$oj|3On9HABL%DxzP%ck{wluYox(TW_Jmu-n~ zC-$d~W1n4d!+~ezSzYAUfVNUj_Ojt-m+^zKW<<1}4t~xJYRXu8@J;;>)7=-rMXXek zzDb)+Pu~WF_qi=D?vZpq1M&dv$hAaYq?(s{w9zKM-n55GlmOElSQGgjUtigR_NCx* zMBE%fa}004PeDmbKpPHwGJ?f${*igsSjN^@aEXO-+U~>RIjyAC+Rr5! 
zg>JO8T`0u+Ub{Lo1!;BNeeGng!x(1ppeONt4x7Zqc}>TqZSWsoYx1rZthp%Suj0ki zaSa7JLuA<|;9-IyC&u)-&3|KmqSi-A?yEpL`BxtCX*Z;!H{`K5Zk#+LzNqa%UDA?| zRC{guFTBNG^!7MamUEUk^5eXC;)E4N_=nhI=p(k!Q|>v!N@E}^O&A0FqZcr>DQ$x0 zh@`J|bKBI2Np;AO-+UN)$x#4NCGmdxZ7qAa!X1)}HFt1ydF9L6ptaqa4{7a?x<5}* z6Paa>U)ZbYTlU$W)3ft>WnyJ~D3E?sMi(|Aq>(m6Q#w$ zq6+}D2Uv$*;T!YxV}}94`0;l(9L7XQJG>jFbi{RT(GI3-)g*i*v*wB`dVQLh@S|)q zbO7h|DKs0XQJ7*H3LNcCf?%9~7@68s5ZT#`9K{0`_HPYzga(NHFu9rS^RRh@DLS3u zNLul!h+TTPlRqAPMwvBKt)*UsuCwxqMGNJjKK0j6B8Z2bqWlPhgQnS zyX&>GYVGwF8|!p^qJUx_DIpI?mgy({!lB99vv+ zWJ|S4m+BCj!g=Y}A^)JpPuXh@$Z!jtsI%cd?fHwU&Y+*Q+%>MEY8y^!6*oT;~Rz zf;?zMxIo>AHP}rHFH+jf>Q~kpbb^NZe*vCas5a`zA2TmzJR@HO#0hhUUHF$Qi0e!HY?MNjW@liN?+o(2LB&WSp`>ijo5FWP&+%V*}Jx!%= zvf=O^hjC{s*E7jj(t#Db(EZbxhieF-4i{!gz6SYU_D?&Txi=I56e9$nn88fte^e^Z z0)}bd1l%3rjff8>rM3x(|5vJVb^Vm@tGMWNbl*EMBJP(g@JUBah5diw*xx&WVR3GJ z++MCAdHrJ<`u1IZKq!BEo?*Gzv!mA&rdw$ zDfNxB-?wMD2SV!P`ZZ2|WaZd)ZRmTcHEe6oku^N#lvKKl3DRj2pFqRnhhDBUmF(Ew zqE*44zUvAmYOmFVAbx?BFa?8tk~5BHJ4Itih#YGuP#zFYKVF~zVc$Gys;b}fbrHJx z4PR&QH&&>`e|cD7eMEKw}R&q|+l_EesTyl}13bYDnXDmy&! zZqd10wcJvq1Z_pA*nN&y3`hwT0EWhtWhCzdL*kLa#OJFS9}tKdW=b23e%Vz(Ziw_VL5Qami4*=1ya4c_y@m*PFSKmw8y!x}0`mGvkovye&6el$DYx)o%O?`iOtD|l-xpA_HZ{mi!mY~pl z-Im2<>d-W`btO-OUaVCM!&9Wc;{BVsj^Sly`?fDXJ>qZ`VO7m9%rM}hDXI^cik;ZG zaXPHrH(l7dY$8^Zupzpa~F>X<&RH|$h;vwJS#t2a^N~WsG$Wr z7Vk$wjyvtD#!XVXxLJ=>U5lDtcCin1QFw31qgFWBH?L5ll32?ym+fp$hcbNoQ+&Y! 
zUy6)t*G}G7pe@Id;7EAj#AL5)O#Ut&ZkF-yUBOB%K`%`cmIBa&4ABxylt zxDzuop#hJI4Lu`oRsSneP!4J}1bJhux=Q2|l*AW{M-z2jN_mc2(KJXgG51xEHLurc z5sx7MJR<{)3t5c;`z9>~|K&x3#}du`bA2fiJly|Fi}38c6umEZfvxwieR#c(0`x*U z#&qS?RJ>NRaV7!2mLuYvs`DZv130_JzYIV^Qzz~eJMfm!&3g^qE`IbBZ|#XhPNymY zU@q&7_!TVez*39>CGUDC<$!^$F#Yk>mg3n7epWc=0;=jYY6WN1(YEP(sNWu2XH3Do zp_Fs)+w!!o7u)yz6FNfyjy=!3L6y}vtlwtTM>5u{(LiQVQb4-ubmz&vVXV;2ABN&7 zKaK2vYd$Z7RkVJyq8a;0%Z*$*C5ka7)MrIg3a&rcP;{USQPyRbFI)Se#gLp6c+#kr zHZ>yd&47F)Yun4&s9&^x6@(GNl!~9C9aB@^Un1!@D-gQDXs=juI5rvuP zTj}bN+QRWH#z2aF4a1EP0a|ILjz=FLva?E9am#BN#80cq8D?0&=*8U+FqF2{Bp8ST zTZ%)~1Ov(6cLkj*Jm$DdlO6^$SLid1RryISB8~SSZn`MdYM}eX0crm2hxB3bw=z3E zKj<0C;=f6q!1NZ`8~?m(0w?+C7qzORmnpN6T!Go^LIINpfLr?hzqkb=^6~~H|I?Q^ ztdOsRcjx>Kafoohs~AuFtfk!z*A6wgajNk`JUGz6)u zS~^R*Vtju3G_I}1f(vc{{=_`syOHa#6&JD{o)U;nbYPRh$=Gxv*ZBA|Z7SlP%~3P^ zd7UcY@9ACpkVG>i&W)%4Ap}mHtP|iMJtFR7kX`L__^vm}J>Pui*;O)PD(13ndBDp+ zC{E1}zw_;n(`@Q!yGTMFBdc{@xEa(Me*5YVmR)3aS-7fW;Eij*=-jzn@HA<8e(bXC z$2Y~dMEo2%XhwuSz(Dz1pCF1o=6lStUo3sG1uYXOYhDiEuRT$yypQ7JcFqG6oN1_a+rD z(?fR8aQ0MY;eBa)Oy9bcZj&$X$}AwyUU}Kfxg0dnngLV0&FE(C%*VzD-oYLF^E4Ck z+*q-!nzu|!cix(ez{CsH>~lchp?{MUsx?Ph|2E}28@A9Nb|IE$<{i zpIAEzKltC$9M~AjNX|oU9rYvMf@UW8$#8|MFi~c@GR_sGNl*IQL!89u#Re@F@(Ie7 zyMlEY*KkOZnCTk71KPKaop>mL=a=Y3KhffHWUpwOw+r>8-7O`WU)lf7V3j}#RrD<7 zQRy1MG5N3h31?*Dt2hZ|hiL)kxEsLr+RAgg#UncZKKb}w{17liCtv&55exzyfk7N1 z+Zney%#p{kNRU)K$$9Ur0H4 zQ3N_00M?lOvi7%~I733f3>3hWZ=W|Hr-73KAd`$#LuC}*&N`)ImPzt9^=&c)&~Wd!7fhGMO$!2uN3Neod1)p&Nlcn2*?vK$muo(MBFvD`A z>|ARx6>IX&f3z5u(73M~3NPjFMN$!k&QOz=A_^LLe)2G2XNY{%3L2TQ^5A4LqKYkjb;E{||s(EfoDQ$`sRAhg}g z@3~v+g=iZG)#1|fl;ZO-1Iy1p!Jidc!|orp>!D{{YLp+Rs%5o%!6$NYdQ^hPG9Nb7 zQ`gbeq{xj0RnzlWWG}9#^1=pE_Ks%GouVulOR@ELu%RyG7qF{`{|$a9vleaRAym>O zmbUgZ$`)5~Pc)<`^`|5#|5n8N5voU%$3gHrf`1BTCqTixlJ1`NU0ey2|19r2 zrd=WHM28CmCK5l45idhZ^p7=oQ_cDKJ8neWU>XS9rAa?$2+RfW#J%k&B|d*%a0m@_ zRsl?LM0$U=&t>m~Xv#wVPd{el9~0M(70%Wf*xp+3Pd|1>MteuvHh`1=Y4N8n$b0NQ z`Vrpb6pEFKuj7XCEFP6iY=+o7RGX|wnEb@Pt-H@0k~vrvQnUfK-P4y;T*VXuX-_m| 
z;*|7L)6}*a$$L_corbOG0Pd`Hdo*PJ0Xo4@ZdKb#aFW`JNDQDoM;#`%`OnS*YPV;J z2u-^jz|SYnggp5gGrA;v+D=#99UuSh3b+jo%y$sM^Q+-wV;*PyRfQ>(i}L$47hR2{ z@^MISd_f7=wR+qGl2JA$IAEFkW&HwkCk)LMm)bP5*EMZXCG>NByZYar2|K_ZlDV_h zKs~{SJ1!O>KI|_|Qh)ul!qqagFX#ol?9wm}w@n@qXMPwpfx=$W1MZ=e!mq$`{a=~K z*KmIJ70TXavzQbrAdq=-yg-}vmuK}et7YVwCCU8qkrtt^Zu|% zzo7|mcxnDP<}MW@1@;Ys9|0ur7IJ6iz>v_9{N|0@ok8R_jQ9-&w1#b))Sv zM5U#*gefl6!I8a?r-_FVfumgD91UL~V^dhji+$o&U7FN|uymIDP>Q5$-*P%`LGf(9 z?Ir+_j*obO*$_kk{9LQ5@V&%bA6r;qoF3W)dU0RusX9wpaLJhme@!AgX28R@>(Zy) z!1{z8RsdoFWW?>?EnR*A^-8GhgTJTvO*LAgtbut_efvAe>)vnv#7yve!u&EXuOr;c z(cTLKAmr`(0ISznowv%0QLfgJ%YkY^{hlJqPW|QQE=y_B?>+nP2jA)@d?oz<=h4qO z#?|%QCayJVc;1=Zq9uS{1&)3FOUGpG?qbI`qZ%LQA4|9Zu!MGMH&a?%$Z%zmIjwm+ zI@go`t+&!^dLcUt+zMRa3b5-Ieq%A_dA_9K3ex}4bOdRyxzP>9P2Imwdef+^@>Ssf#n$x$ITvz& zvimIL(n(Hzp4HbW7idp2pCK%BQ(SQ87efUXQ1D3XkShwHI^DshRT)08X$83A}+>6y3CT8|4f%8L7BgXMG zf3JlsXgl;!jwnJLbJk@AMdYHoorLM5jDy%!=*_jD_W1B$ilfRn4UbR5;u7=B|BBrg zIz8RxSq6YJ&p7Os|*BfG~M zY6T#srLlynNq0j++G)NwJ=0^ZKETxpl?EQ@Z70s@ZH%%25Y&g)l#D=EF~x^s*+kd= z`;GHEoai@@-M=Uh7M`$uEVILp#|?%RRy5cr+|ev!v|l%&0A(Ae=GBn#38ydBNY=IH z)*D9-n^knsGS^4Jk|x9+?U}d9)F;Mt4&?q6*VuyTL8^!%=9YhoLO>R!^?wng0{c(I zsB+O3OLqR43T;;R9t^MA`0x`A*aA4|*{k!RfBLliEQh|GFkDifl-Ut?2|VrQm}ycV zCeBrOAS}KQ!vgJ1<22z7PHoE};XM~amBOsah?p`>9fJ4}8{I6{6Dc*_J>tO+=U~7; zbMe1c$RKFF0#TL~_gwX##Bi2EZc-lty}7Udgj@s@U4I`PLjrc9IA9?4p}~%~t|SKP zxIj8!Lw@VcQW#&X=9+??*gd^bz1^>I+YQn#N7M0AUxMAikj??)7U8vaFg$3^K9KWyj|5n>~ zP_!W#@nv;x)tv$`akul?M$v1%1?3N818Wo!eHZuNLwn##e5sS=7d$MEkLnk#=q!sOV}qaW zV;F}xH&}_2(+7M@CtShAnKIgNtdaCiY10`e0Q1QrILeRRDMOO=orpMr%<}WUwLHD= zZL~Iu4Yk=d6`YufsrGl$KWX}_xkys(4<{=fX4Js*g!t($FnQ7AXl)1iH7MS&xPdxO z^`)8;y>cFLryW`|xG)sx)F#>!6@W@^8)9;gpiM%&sJl@O_)fKK<2;W6boA(e5s2e& zlB98*1pDc~KF-6d;M^63+RxqXznCu(!6n)ajQWDN8|d2ZX%?zt#38%8ez!J_uF-(e zJw>3b23&ad$Ou?PjqAPR$Z*e+`1-RQIzJ0YuY`S6^4_^g?y}4Tc=x;NbNzUIe-uue1wcm_{@>sS2(LWy61+^e;CYX= z%RplMnhqJbWv)(@@!Dj8T(4+lW>b(QO+PThDXA8`L;BbH;T46MliOjI{Lb)3ip>EI 
z;xcd$X$3ifMf{r^GpIHqe)+WL2!LiKE?~k^42(V{U3|t@7pG!PZ}7g0e@zltt?#`B zx~hNP*^f$~^F&HH*fFc64wd-Buf>1N)x)IET*{A`Bu1lS zMEuh%4yYupFn1WLrrBLuDHJ!w(7DUr? zKP-;9L!8^P--|x-=5)YGeGIlvBv~6$u_oyf!!EdrJuH5H{M>ISUC&P`4grk;ZUn`w zX9-RA!7Q}`8j1fmS(>#l=aNefFh9g#BGOL(5Wb^D#`n$O`9Z_*kedLk66aiIJ#}g` zIQYl#gRVDAar}1mD=RS*2n?H6khb*lwPbB9`;~{OK>5}TjgeTb%DDsdok_$);@%$OdLuXLGhx8Bgfh4Mzy9> z-a3*qZ~Z@XePvXYUAwi4($XNCE&=K81_?=NMH&PIq`RcMQ$V^Dq@+{2B&54jq+`>3 zYvbedzUPed{c_0IDBSm2SIle9d4*BEo-}xRKjvBvSy^B~UJt&fY;IRf`aKF;VMJDi z8R?bGJ+;hYoD9nz3H`NHi5Jg6;Jf@A_;eD!h{Ov1^zXuq_DcPVMn%oAmjo^@PBV>S z+G1xop&T+6Oj3gNLXmbArYRwgb9lJJowlNzOb2&;BsJ@@pS)B3VvAgXQl8xJV~5ni z>N)q>M&vJc-WE*iQCq^-m@yu5pukHEy2^Gwc(-p%h9$t{DD+%ueaSPCzix20Q{8c4 z*_xLUy_f!DW&T2v;sf@`MQ$%*f~`E^weaPWts6oMF->qb+m=k(LTQbJhoHL>eLVLZEE< zjC!2XY)`?Wa<&-)!Y{umNo40@dh`jD%xQL{ScrbmD|3p1%E?8MJmBZ zd!KtlgcVog`bJ}ElI}7nK3f;5X-Zwn07#pr{+(jgq_5@w(<>CF?d+%@C|`S^*1ynW z%vj^HhfBDM8<>kUlriK>qf1^{?w>2Qu*HSN+siljkuyJ+BIxba+38mgt%r=_%xDPbub;7hY-wz?<`OHN3l@UDK0pAj&V9B@9y>t-2J^@;>q z8Jg35Wn%~&Ks(;c~!KnaZ+o;ErK_fxz4dgDN)TYZJk>raetGcg9CaHNpjLSv_QlTr= zUhZC!?)Nki?#tu$e&W$FpiR4p@%=5dCwLViYfkWBEnW+%LtgiRuTk)b+rek6QM+nG zK&SEMy)?DD+A}L0#n9kbP7>Q4NWrN#VC9(x~QxZC|r->d{0SMt$BK*8xu`+YLC`Y)R{gH2)Hbf6GC^t zdWhM%G?`tNm8>z52;3mY{EGeUY7BeL<)^D9x6{EigH57i%*>@XHCOD3YxzKBS7Unu zdBBF|RM}|i-P2tA$OA5UG5FkruAu0u>jeB>SKF zt&!+&QKi49#H{PuZ^-^3kYDo6tag?(J_p_8Ox)#ceNKaRMm*sR$=_s-Fb=^KGW@%4 zMD8I4s{+xf9(j=_6i!Ikme5R%=2hcm9>BuXnl^~EdgMkHUM0gfCKfMv2FJ{qOle>M z^EO#Dvd=}2NUuIWR!sNe&x!rg7|;DH4#9{>!qx>v%klEOFcUKlc0A_Jn!RrQRyIiV zT5~SBy6s}t7Z#s_TE*4(%OMkzJ}Pew(BDL_@F0c+7iV0BPNVa zo(QsUr#KPmZUx6}+{xD2ND&R4>Jqp^NmJh&`wG&r;v9lJTCRn4=f7{4G)+ij24&di zHANSp9n157DE-ZR0KLw?Yb^@gJ+4onXWK1W(q9FF7a5k?h%yCs#qC>IEv()dV=XGV zpXD(r@gHw}vnsocw6u8R->K2Q#|{eQ{h`a?^=;(jL0ZHR4r;H9WkaK}Pj3xR%1@E5 ztxR|Ija1WHTU4N$>A&;6Zn}lyzMhXguNs@=dfMFaQWs}BR@ZCL(ZS_(SnKx6+;2k{ULNt_*3_2K58VhuBCw-* z8Mm&J)_7GFh+g*en#G~Ov) z`7+1zQ#h;C?u9#jLMTWB@xowsHS;p-a%GR9|9NWjR&(?DZlF^FP5VJ~QRgzB8SXiF1YF4R<07~ViOC6U>zh0hrhl3NMOhJW} 
zXTArp+ucY>xPX@^<1@`P8h|+{f5MdrJ*oB`Nmy!di>suGZ>(P1kS*{iWJuj>x=izk zlXrOP|E!q)k}!d1;WS{fXHjb5Z;~zAznbJO52L?vVqFd!SZXkDcTt5bH!pz-F^^0E zA;Oiyv-nz+Xk3XRqAvyGmQl(MW5eI+6Zq{hD4Wtiy5bT1>F1PJ^Bx}wO4mdRem6)& z(1Nx>`vf}t^1>)b`&P?{KYLQp5oFNLgYr+?71Wnaw^l9;;ttvu0XwyU(A%%vsz|c9 z?F#V17k?uNhRKg;G)wH$#TK}V%E~3(TV+cKGpzUlUHSj@M`{zs|2Ps=eZvw)TK+cw z_owgv3Dk!MG`^%Be<$s;H8ya`onwU>KY#uN>^v*F)Rq6w{{{Q*bKam&sGi3j|BU($ zRw9MPV$Zhn)A!E~45=Ufuc zS_0%6z%3vcOcyq}WfU5OyaM^v->&?sAM6m9TRqa(|&6T@T{*`cy=r5gJrd&`nFBYi)O4KIvq43R zXbe6n{l;g_-GR?()ArE**XcRy629@SG9UcudX5?Ksj)<)&p(i7;o#w*u4kW4l}52( zTeLbKRK(kydk_8cz93!BUrgzj2qls0%iH}nVa+tbuT);7@#}ji#EJ*YC)IBdW5mM0 zf@-O2cx|;Kkl|Ci8TW98n}EXsIVj{ixEDNcJ6qTW_9e#`w`&eyRhqiG-9Yj|xZ@jGJWRFnYMqx|KI;*Oj@cYN*YH&(!ZTT=aGT$m_&i6wx(71HqQwk0 ztf_E#8^EU6@VVh-!$75AOQV!-8QQc`XDGM_qmKyM-H(sWiTFH7r3MAQf0T7PSYF#E z@lE-KHYTqsMJyzL;>+?Lu%FAWnVPOo`74V^H_sc~pYAYg4~5-efv$~tvr%G&m!tFG zguAzkXY=~<8hZPEt#Hu#B4#;hYQ}c?3U0jFSAp@2d1*#PhJC%qEhnJ{ucyL+HfWXM z5t2*lZLcAwp095jB`g!Hny&9B@}-xTxxUg3t+dRsbK!7$?!k6Z?{#xVSv=?ZnT4qb z0tuRlEr)W2vkujK8|thZ?ChDb8hSu>8@lIph7ggK*}O;M#pAFbK z7>P3fQbeB>QQks&coNfGol9zT?$))(ys{KqqUPxxsakL#LEMFYlTXb0sH;@oR*8I0 zlEOanJ|qxB#Tm8vMi#+s;;~a%^a#IhT8-~DGzT4q|0c{ zGf1)htWn{5P>GtxhxGlPn-opV>*sfR3eS^;{o^f^%5ATgUcnEGwY@CFk53JYgwON- zAaf|2jjPPOK}kN@XKVR8&Id1S!hD5%&PNg_qBU(Fh<*$KFcaK~in|Ir8O1H_7cZCl z+`CPqlkwIC=u&3ng>c+dJ5qX`I#-r^G9W14D%Yfeh|7|>)WQkd#~v>uyw(lF7Ca_3`usz7HX?qWcdN^!DSezxdo1Ovd%Z#hgZ&yyiLauS=)gK;4W2s`A!t3{WetCoX z`1bme=FO(FPZqHWW|~u6xUkARVvkT*+#&Ya)uv^G`!(@+mlHqgHa3(};L?sr*aZ_x z>g7AFt1Uq_7XP_-L(lYwVCC||CIQbtkA_C|fxwJX3Ok{$1C>Oz_g%sNol7vIY%shb zftuM-I^TN{fAIVu5N|%0rK#dhe=gFg51_>@QP446U7!aBG@?p;T6n^&vv_lTc?XIa z>h@_DBs8O!T{Uhf?i6JyRo~z8_94)s61loy{#1muLq8`;U|(`SMtom>DpGi-u?4;} z;062=D(Yyku*(u=t1bGs0Wq&5D-9CqxkyFm0#VF;EZ2phUblWy-RK1V4(H%Z=L=ah z*88iY9bN8f?2o+Y zqsWXoP0;PLx(73sx2g3-n<8J`h3s-&BnW^3*vDzp$>pS^UY?&kjF|~n$VAsd4V}Q? 
z^Mlxe0a@nFy=%`gW~RlTBUT*cdoRq06Wl5a(NKH&d^UFw%Ak6My|L0n11V8o5mj$z zWQ#noa*W={aLunV^L_x6E#q@#qKG_Etd`0d=IEDpAs%tuntP%k;}K1{J!4rRJ5A3a zHt~zRlzE*~c0NB%Q(oJ!aI-_Fkm?mwv%DhrLEVwj#OV8UY|rIOXo@V4#}?}$ffnxk zO!*Z~vIiot6?jheDYmK+pgd%?$M!_qUJj~>kx}>6wwWx^=VifX@N;O36-;vRre_0vrt<&Tl9EeVe z26JVNXb526RJ@5`6rlbjG4{h<^i$VmWRuu~OFFqqd&1zMMw-Ox#i>3k+@{nOR4Z6m z%lh@?i5e5Mk$lvN(W2a>)Su;ba!f*L7!%L@9%z(ao%EqtP`o5!GFb&x#x=IF+W<7V z`Lj4yvfab<-_0%5IvrlZLrU{%xeY(hOJ^x3-LBB~(_13{90JwASj}3_9SooHNXkN6 z(rMsKcsv!&zc0@}%jt-bFSQClO|6(J>8<9xOLVbmNxuV<_|zpFkLCIQqX!JX(>Ez0;*|u z$GNRyH69MgyX+YI@5BD#`Luc9XM(tcWN3>J-ZBz+VTMZsOaOkKSMDqK_9a_L=(3@WtFBw%ihwE!qByi@qc>5pKC<=Jb|_- z1K3hAEVB(Q{*~+btt$sgaGFki-EUYy?BmZAalWyEK_5->NpD4fnb6UGN`^7K89nwd ziQu;3ho;t-&_m^>z z^h282qlPX6svHkx_q;$_$hpj2W9?Y7l_ZJNP&Rcdnx*gW=SaqI? zE}_~ortNHHep?(1R$srQNkX+hV09mfB;(0{^h)hgDhKr#O6jze=vi^swHBSr0?=wR zJG!u1r_?jMrWjPosY|F|S6yxoH0IcbptmAM^O*kI@BsW3=dnH{nITcRdJm(j;?y?g zg8SDm=v|WqP6O~lOZE1ij?HHh@Y79xw-eiaMJb1f6PRr_bk-htJ2KGuvi%R*yXlkN z8o|We{M>q--2bQ-z~w>Y7Jb(9g5@S8o$W*&M{JXvwgBnfMKV$)7F`>-mP5tTy`{!! 
z?dA238}>B?#@+?a7s-TR-UEo6zbpTdyxmse>5E%h6pHOrFkSudg<8x(rphFU`ARuJt=*51*oSx{6zNd>Ay9V-82Qm3n49zs50~$rut<+LcSNkMB%VbIQ_ua=%%ZbYfr^vV zhI)+T37;RovS}@+yQ&zTfXDoR9UTzbpj66zXf_bjsJ|Ue>ye@;r^2^Z$wy=cm(Lzi zpeTF+s?>lj?gjU@D&!Tdd*K#hcAL)`0b9w20N`eehhW0~l1~XXoXE3I=HvtOLeHqw zs5g@1T#qY))4>qO!zAzpJfGVdYgLXCWhG!pa8v^x2EoHI`GbzpsVzY-t5_mpWr%d; zp(BCkmK3=9pj&)_VU$5LnP3BmNBPm(9`|-MtipL9n@8Eu#q=rIEayuV+fe^TLo@jB zXxPYPvg2Zv|>}GPs?K zYh$hm3I$34nN;|#0a6+V^#ZhBE~yZ$m&_)J>hXcgK)3j0F+Gd#Lzd~!((X-!F*<8= z{-swV-Cr`l_Cm?YVuh4{(R=k|Fm`~^#3#vj=lcH*p-Qxv6-emQ&sat36k|U6iY$Sb zrj7y$%HTOv^9Fb{rS`phx5qLPqN+kvaL>SWg4V+L6EzcDqmi#f90_0tu~xKlhxdv6 z3X)Sm>=XXH3gorw7P6$SfLyQ^u^Zbbk}RGG3JLV5=iY#xz_rSzw6OLS*Aok1$~;67 zY#<}+2YQOeYz)V2^{ zS{^b`u7I8BEe+4tJuFTh&f%Z+4Wy)QV=_(&Y8pQ);jb_@!7B?{VIDuWAM)%>A*vajw_LKF0!LD!0Pcxtp4^}wmwp5n<&N^_9*$tW4*7wDNtHm@~(~& ziMee^gSjplS;`+Wfg0cv$nPE#DL zn+!8pOe}`NGsT5af#KeVBrZO&FUUFy*$5qgUFa(S+MVsz@8R5hVlRhzB`d5AYoDPdBee7HpsBSOZokfgd2UwHVsyM6JzyD8l|{zF6>P zK4j2PtJCJsF~lk$w!Zkw#kfL@_r3{cju4j6Qz(2xxGvD^^jkv%$tKiu@#N^z6`VSF-Xm(Q0v4og6--#;@>Z_9<-5(IWwWMc4-wCrHxPf1ygj61llG^64!|ya5BpzK{ADn0TU)#!C#vSq0lO_ z#@>1qz)_4=arz<)ubf*S^KVZwNINSVA{lf#;=*;aqKAU13NB3dKrsFX{#J8_hY&p(_9p2)pxg*t{6C-*UiJ=ZX!AZ#Z^$^_K^S^U#Tmsx6 z#`>_KHy0Gws0OknWL3wu=*usdKl`Int#px&nG8iA9f80a(@=!kyKpBH>1Cn^bRHNr z(G#}%LApBCTJ&FPBJq~fKe~gVr*OEss@*7I<$Z{~yqw2u7+=wzMT8Q?mB5w1{wMcl zIErkqZ6Nm4e<)a{?=Ix$6En`}LhQ-EFQEUM>`TpKTz>S$t#41>jm_`IcuOzun*3Yi z|5Afjw!Ic<>(r#8f}0x$G{ejWKr_r+MUJ5#fCWE2aQ3#O{@mUSS@>Swj!u>KEf_lh z4Y1}S*?@(YG5C=PvR4TJc|Og>`KISRq@Y@WB3@^WIAAk|zT49Dx!H&g;4A!Q_v#6H ztN#;WR$=2lHJMfIPceKWcHV9#bt1H8i1IeX6Eh{3>xbvtX7u7fq3{&;UJ`ns;4hO#r!$}WHsmwI13%6qp^;2$y+PU=kW-=`Ll1|osbf^m86 zJMbzdF#y{^VDn?3m9W1^j}WLhcCH(>T|bW5*%;4RmLv^3u5zXVR&4$SV8t3}YK}tp zAnKAQ&;F}yLFntI!_fU1JT!}KF;!U;g4S`gW_ed_lL)y74Gbt;9mlRah;btQy}T%F zLx9MiOZ@f^i=rvZ7o0xemP{tDb#-T!Y#EaYbPHO$S+Y!k?F?J-!J6B+x)*}>a3c+3 z3n?f$lw5npa*6Y=01X=HgQ{x%2uPW`3rSnpltf`R%U|Z%hN3~p3t&Hp2WZsuH{t?7 zSW+9%w$!~TBBt@N_6we|rzBl9=mX9sy%tYfRX~q*au#5DvfBBeYHn_^Q3rAYwFJ 
zy34&nI3xfLy}Qp~>BEnV8V-0D`NzKmQ-Ob{sTpF~$Vl-TmVVMu9|R8?|p*!%tBl&H)b0oAr>1YTQqXBbUoQt8c<9H?Hjz? z3n?MgK{O-^1V;$Q0(aE~yA|Tv+gUTw!swerM!%rkH7=l?w28l?Me4B)btO*nAUD!< zY~^sN>O42-e@vyOYjR|VyusU+S}0vm7m<^k0^B>Eu2ILlGa4>CyWsvA?Dx(_vpP!E z4lY=&_3G>L0xG&|w@8e*M}o{F#-Ei;FbY&~QVv`B4f7BY>7 zgvv&Pz7mS&ErAz=Pzow&#S4dA47}jy<+a1RrEKP}|C_Cf6hD!H9IVET0g_#UwrozK z+E^ut(pnB0+fXEEtzhF9W}ODTyYUO}Uy()%^-3-3F&5}mG+l#6q*XH977{%1BLM-> zv3vF5#jpbG+(ECuXhk@xY0~7q>Gw-H-7*p)#nJe3trozT`u%^Ppc^%k`f{M4YGh# zY4~;mx0uv)6MD_A2#SWdHzR@(3Wa64W3Y8QfD)eItxzy3lo>1!ggR_P1Ch++9Jd05 z)foA)uSE`+gae006&F)%@L#S?L`&=u95|%vcR!{Yn-{Y^F_u^(`aWJGY~XRcSLz)^ zvE|GxR;4Ry+iyu#7$uB7_l7PL=^l3<8L&ViOiVtes#21)McdVy121-s_^rMcQXUg>>C|I2TDOkuEEN8f7 z#<)RJR&jk0bPV1-_?a+)+xdc(RQa8J2n|aP|KC zAq^iY>M*UXh$^W+pOOp{kWhSGy05%#y8PlG*-2EnpSaz~3pw3p9SE|KC?oY4X2nt10X5(x`p zwk0PgkEt{kMBH2Si_j+{dSdZfT8H54UUxbx_C4G)gS2)3Al$XHV`2VPg2>by8mFJM z&;e5eDsyQN12NCp;VHSfpKcJ0KQLwLdP zTw;lU(+wKUquU$P?fJbgA zA1#!o(U<8tX?w&qAMYsvHV3s+Q{mpeCCgku@8tZ_a;% zgfz%Jd%E{QpJyOqa-$|MLm)U5j*Zf_Nz6I)ncdn;ugpioE84cFhfi{glFmYXa1=eNMPC&VWHRKkf?eMX_ zSY_&UJ&lV3(!Hv|t1;pPvZ3Bf45YE@Xl*Z068I19i=)e&aKm5}$^pZg~EiN^J=<0+0lIthuo!>nK zENg??ttUZD*}?XAl8bjxeNgmLbzQ*H*66&X^7$hMXq7h*f!(>7Ua(O;rd*uodDm;1 zqgqM+!wK%D`K9Dp36xl$w+x4zpOjo}X*!c8cUNfH?l$BV_7Bz<3s{nA=TxugL9Xv-Hh zX8{ffTW_Od9SoZ$jIrdO0>5LeGMjBz{LSqp58s}oic`w=Q9+|4d*q2`Mu*6#=;TON zys9%ROYDaHZ@A{YCC_xARipEwuK~KGRsG<$8H&1pJ8$RT_M6%d-BBxJmz+ul!UC}j zhuX`rz7W2*ol83Zn1~|~Ses2Xw}jbXA1t`y4RIa#)k-1el`wbs?Ca`TcE2xrt?T#k z$ep#*g5fy4C^MHGZiOmGA`RE z>hGD_fADJQ2~9&TB)Nvh?E<$unK5*kBWCKK-+59Q+KmBEs;J+Fh1ICf-PU(%!dHKB zJ@kWbne4{c+Xf(4ODQVtJ2^|4$DR&Z3VTeya%DZ=F(u-zdVMp}EH8F@80k{C8jY*6bb(v-ceFRX`5dk^f{uKmtoA61_78_ZQy#}nrwF8w z|Ma$cY)W0g^s*-0ukLbko&$8YpnQeXtcBvqc-`S#l}kzO;YD9|#1>K$q=;Jm(qxFFtiaz#7lNb9L`gf4)>I>ao}DIVQquUGl15 z?=|M4-rOnyfZh_)byw(f3{SKtmKo%c!9n;KT7J|h813TUhjzYH*`@rtOF#h!2Lvi^ zB?NvuE&=xb=7DJyc{I|KH&7C`p^zV1ZB0Kw*CqwJwz&1>Z3F^WUHJ9Xmid)x4t4k! 
zR=*S?;{qXIJ;+3G&DKH1b9> zjJ@owXKd_!dK)lFPHEG6!ZJ8_D{wgzZ^Xa`E?c%C5(zXA0k@FYhO66#JGJTHo*Srm zINchPNrUZJXFidv9j&Q7{G9{aWi01-Y)}Eo--EmHMne z)HoZzO6~TE?zZ(WpiNiq4+r%iDq?t~-(lS)FIB%)Q&)A?Ii&o^spVTB>Yw!|ZBJa+ ze|7+8iswXjIB^RoGGbJ4J zsG6#hH=eUnRHaM0KSvl&A)!mNYJka8JNC9l@@rgc%jNmy*4y8@Dh4{bKp&NXh6sTb~iee-H%GZkRs#r<*iUBDbjQ z1}HjVs`BF#pF2A`Bbyl@IhV=xNCZwmnk`SSKz#ut34%?88+=WY7GoYc90HR1U=tyq zWCp7JI_L72`k4c;htMsepbXVMYr;{>b%j#evFZ128;cIPITXc+_PESQNf^9)r@o8V zy%)-d0+R6cr}Rd9F8)7=t?Ew6hnst0!e@TZa6O_+Bao5kdbAE5w> zujIJG+F)afPFl0OGxDrSOz^(k=X8d8XrQ!}Ys3%jWqVH*9p%o!P!&G6Q(dUL4vu1N zg>o9nv0OU29dtXbS;oa-16vxW_{lPD{cY-8Bqk}z14~N&O#%hKl^?`G=bKWiy}gp( z&aBq-4&0vBO(b^Dp5ttV)2-*{1^%YLV6gc(b&eB5tQf7i|7cyV;zL3cXINLDq479z zEgQQ)=E0|VBeFW(gS#A{wEGGShgOd)->rxnap?e0dXw*A+uOB*qHi^7N}{Ju z6fzxAD^&`bJo|r5*<+dG-ss#^UutN)aBRXTzFxb{(sN_{P243q*na;aFM6`oQ;Q&< zyGlw;bQi?tV+wZe}qH61$z?Y%=$x_o|BJ*QuKc3RM8rm*ys&5_mx zEgZdmM!NY7p=u+prFR_eA}@l4Bhq()WQv_2=rY@JPF~`vjNXlIz*))5brwR>NJh1V zUcte%3E9&(ZPQcV^xmxl2X-{;BjFK`rY+WC`a!+x;VFD$Q#j(0^Q4L`E9r}U-eKHU z*AK_qie(N{a!E=2VqO&}4C*(PUgw@sNPRoc;C^am0dtQ{Z7CEDP%48}rnJ5FuSce7ttpIzaQHIC_JXt?qC@Lokhb8!Fw_GkvY< ze|VA9p#nTJca)zFGJAB{8NIx`P87~6J?6$2_N229`%H9hAm1TqbhAvjq!Z{eqXrpG z3u6Lfo8v}hDbaJmD11pRx1J5F*K^-Y6$61Ic81Ze3h(ENX%DWot~C8ME^V(yU{oc*7DNFzYIChtZM^qfW+D-6K) zaxmhFbpQF|b*6dviD%o3nG>9Gqy1L(#KgTgDsjQZWp*W`Cr>|jAzO!llTOH zHj*kPuM5cLRe_7Y58`K&M{?z^?kIg{_*{C;Qywby6j z`#+77_{|9g@~xszqla1Sp2N8xy(deg_D_~@Fg~)&!M!H#Q-9QIeE2#{d;G)q)|Bqw zx#@AQ-3SVwSOv^f$nU2TMjCH%M^%d3gayE-d?O^KWXijbqHgJ%&)pFy-SWfIOGLAO zqT8Z|js}j=U-61_^*4@Z;>kkQEjk|4Q6PPX`PzHGctMc?1Zb3UDT1Ab$@jLt#JTK^ zFw3OIx7-gxc8@pV9Oo$g^{v8qU9C1*sEZ)=N_~gr<^eUUI;S@ z;zBT0;qS}Cz;%9%Tr0nW82qejc4~$tLePr>m;L%Ss5E)Id;U3_++efuVDlMDQIz0q zuERcaL+C6zSoyO$p;;ponp+Z;Sk-D`3?AMyWeLedyESF zIuzu0E|mctX3H14xgPb~7bCrfb)*6whoxk{P}(`?pL5B_(B;Is zU+I$7?=n${RT+*WlUu1Y9jAWuuU$%~%-ixJe|D{`03Y(bmjPMBqsu=q2%47jg~9-T2_1N4i}WBDFrbC)VEj_!8?Hs07mt;kPSJSvaNG*fyKmJ2 
zzrGZou-u9a+C~$ldWG3nwfcR8U!>qR-zoq){K8^jS0cjw)7=*i8Buz_I0Qmxl0(Wq4NPU+kIuc*>EwEqCOZm0lRx>KH`V0 zV{UM>jaQ4jHe6x-U9G@k6jjo#<>fTfK`L(W^BxUc&u_Q0Qnphv7P413TG+QCABxne zxag|d9-4OK7YYI&R(Ywq!PrqX4>T=x<~CY5dQ&ZiPSx9M>6z>X589~%&Z~)w@2UWJ zV!V>xBHF8m86>Mw`@=k$UVLt3tf1+bc!@ujSk}*s#T0AC^uovZ^G^5mK7{R^P}}qe zfiNgi0KfUIRqL_8zh=54(ntREo|6d@^2SrT2iKHKYz~xf8Vy;vk7N#7TQ5rZ3O9hDGx9-J!h|}^yn-M!=>C-CK z{`JYfR^RKn{k7$IM519v(NUT*J$+9~bd#KIOGj@uEoq`{*Asu2Z4T$h& zv=9HS$MF`6LBM@!#|vZY*qG8^oE}l_2J$U`(66`(GCz^x|EerMYBZUC^ zY(i?)3^7%R9^Z~)gtd=8hz~cCY0_JX%5YT%A0!laNxhbE+g9Y`v0-~ClB%O*$ za47XQEL@q)ER307tsWCoKEY3x|KPvP=3HbUX|zLc)H@N+uS3oVo!C!a^nq5&?=QSZ znoP38eZ?P#ZVlFtwf`;y23w_mX(tSFp}W7aoa-hJ2)5$Ay2_txEt#^DbgnxNaaP5j zPNnLb3KCbXaD8{{Gtm^D(5XeOIO0@hPiFZ!_bO<@kHYW$Ej^dT>;n^MFbkXA))WFq zIwA`6aZ;U2Nvp&|6#7PjwBpfpjutOPCMlyO-O~4Sl03C9HG3~&Q&{H`sX1L!MqN=H z8^l~sM2mEG(P;udvy@r;r7?IMf6~+Ad^}DPx>V^vnw3Qfb{iuxe(SNGtRdyC_ha6u zcM1C;7#S5+j;%A2*fKnh-AP?FG*<1$$rWP^$TywQIEb*xCKG z1lF4w!N*On(%Y=GFy08#v`nTBe>}v_4gbo){I&Ir(@XS4NjXl(7Wp#x@fztBxQ4bu zV(C%8HtO-e$94C4jk6qAc0}L)RKu_#7Be(z0h8)y__wu?fZLJUQ7dZrSC(Ys^t ze4V>2RVO|x6)X(07UF!C>{(GdVjpFmU2v|;=|1K+cf$XOwAFv;mx&T)@%Fp^I~+WV z^QU-<)c{2l298JIi1~tZ=1C$f1yAA4-RR7W=Nj4O_L}<4YD1|=ZYnftU(-1(9Tgiz zV{@2}5uQuzlc1{r!yCLdhn_C!I2kS~n^EoWX;0Y{D?Dk$>;jo{n$nvf|73y`F(sS^|LHPpq2UG4Z;4=qZZkxw)ZPh-jfTkU>C2f zx3o@IN@;&KU%rq1g+5=(@VRYZMR87Zt*_;Fb<1!!e)ciL6kggz9vEqyifKs=KA5Iv zvPP-Pc%z^ZT~TM4v~BV1DxyhkunYLn>%6P$aPNW5 z1AJ1)M^lXz{g;ovHq;!|evg=!AX|P`?)5X|`AVa{#jkrF*j8sb+S@{kgUy-5EiFqd z*m4&AVb!oITi@c$Pf9wEE%o$%d9y0mE%8{P8`K#=rJ^|ANNJiPhE;Te z#Z0)4?0p!+XYq%;U)DFepPBl_ekHU=Ey^ftM(&d}osBcf&4V+k(up%Ef5_@|62UPgbu*c zQt50;uOA7_&T4jv584rT)}DX=@k8BIg6I7Ev zBNOF=m{>$8QiLd*n_6veJEVe&U2`U-b9)=AuxG{@J|tu^+83?2j6gBUT?uOlMZlo9 z3C3lK0WL5=8WAP{Pyj4ISM9zjh4a*kSu|CM#jP4*t?)@`vDf6W@a_t8DP_$QieWsK z`<-ddg{#1xjB&nHSnRD!?YnvdKOfsAss&bFb>`}+>{fDjiv+Sx^aHmaQ(;ne1G7hj zwdN_hWF=R2Khm3t0;o^L7Jf|PO~b2~+8+m09$DZ?C10^t9Q*j=B0cE^Rx8r|v)(-` z0+SLUX_s`)&Of~UYl%bitFpO^Qcc>XdnBh3s4_dz0s!hlc8yl91>FqF>f 
z=FpUzdw@Q1UP^8gusbzL5>SP*L}LA*R$KP~{05P%cE6lW2`9w{ji0>{e9}7IO2X%A zm|-kToG@Ugw_|~#y|Qc1RjL*CN@l_>=bL+KYpdPj(q``=QeW)FNUlvX0%7vKxQ?Fp ztO5_TT-%=+ZkV^{j_POi;cw#evwbuR>zBZfNfM7um4hz_MrWo1EWFi7B#yCqrteU+ z;ml1x^0io8ANl)B{cXN&l!uOS7vgyWCCkIXztW3OuM zH;dTE{q`9lY)AcjPoO&I^amMw`bAyqP|N&bj}oUvu;lQRi;2OVSiD#Kn?AcCy#P^z zzNj)uy#@k)9fH2jGVT^thI^WGa6_xEVe9HgmYgfSU)_iI=IG;}&~d-MFw8JqMKmll zW2v_HNyAL^Rt4N?7|Vj`4g7`6)f))y*E3m$G|oq_RZaN&14qt`ORMeuBAR$BEb{e! zOuCdpw(kjIzq|YyLUcOxZNGIiiM~6o$n1J@NNyQS;W{56642m(`7rxY2j^ zBeyL39a&Y{f@0RD=oM&I=bP|1Dj|l%?ipQFNj5AFZyuS9<7q}dGf%-#w$0zJrUKjC zlcFA(Y{(R$|`I{ zwO@lcJ+J7Jw4BA#rYC2hhW;NsVzG`S2+60kw8c+8l&lIp|g|Dp=n+%+gJ|~ zqIH;X4}k%s2MDMpRBd#!kjVL#e<{!trX=;?<|1}E&t}GVr#ovz2@8#UC+nC z@jF4xwTKHDM3eaZEG?pgKM0Epg`St&_XT+aejVAiuj82|YG7qFqrE(FMz$~{R((%W zv&AKIP<=iEFnjkOqkiBF=05g!(q)T^EgCH>(iV{r6P*3@+wYP+ z4N)rnQ&7ywedB8x;43zfUXeuLYm5Xq!PGm73F%VKZ#ZX(EOW1WdaTnf?`N&Gy7F(* z?-wGXO8W@*kQ@_UU@w+dklddcq!djSD$7kQW8@&JlFPI8x!*(n!YCt5TbEL(5-oNv zm!5e`zP%l{A92nT0%&egn)BXgF|zw8X!=W=DLmi5+aS3wo2}7I`DLV{QY8+&UAZbF z`$bD3YrN%-^z(B)>wG+Q0WAZ%recxU8G>N-6Fr+3pIarkLt=^TkE!*CzkCu)GqQml1LlR*b31?*+=GuMud5;P_=S-J^l&=VrZ4>b_-)}7v%OgMx*_+onKUsY`0vMfUa zRu%EEqigp6@YlUjqf&LGC_ALXqAwwTT@FWronNL_aQt{jkEQVcvGvwrQO4ccs0s`< zbPnA}4-6?C(xD(FIW$rtA&vCVB}kWqba%HPjdTe}cQ>5Jy|4E>d!O_DLkw`iJUsJT zYu)P(>{z1Dv7LP+N3?mtgTGuD5{#sNWVX`;K~%mxoGT&&GybOyJ|O-~>Vx-HGlS`j z00Ip=xRpae{c;SJ5~=$ zqRe9doif|o*dwGvm;tI`<2G2hJ7l_1$zsD|L7vu>6p5qoMk(S?5C!JSH2eco14ON{L3 zy)KkQaIf9iR7K$UkL0j%TGsUG{I^-jXzOU-s>U7=-nC1HUwlNzkw;pGo#CHrN*2hM zcTf^la4VdrdZUxDE5GnUQy@lvI?RroU+NdYdI4#y^7x`k>FWJqwZY8gZuEcLMCZ9m z>sWJKhURzzv;WdQo9CB3xZGLexo%zQ2y^kg%rkC31D$u zF@T9_C)d$8Mf#fUOlPP+x73KWBye!>@1fpojTy_(kOEpw=k*wyyiHa^Arq4M{w?wN z+kuacS#=zucRR=Pr+7JG^G$ceuL~;d&zOLt73MxL+fA*+@HwXS3=fF)rY+oM1+G1< zXf-+&m4*)W>ltJEM2k#GuM4cif_R%Fmi!qg5YYF(yUaKsh3G$&>pwhip0>Y%%}lTk zcFfVZ8b`EFvIn*Q04D2YwCt?Ui9xPIef3F%p;)~uG~rl2BCjtqv`W(IK^8Uo1=%=W zYa=uYHR5HtC!A}On8r$xMBh&;QnmalM3;<4v6%pz)d;w0Ap<~j2?%9nselIn#5d88 
z>q&TXGy}(u>^CFc&m%5G8Bv{mmLq|#I7{mzu}M4%Edu1KVlZv;df8@Lj!Yi%!{%3| z!g{iDcCk2me4Tf2QkkGjyrC~8x58WnlR_aE1abSNPirhCEu_O=^)2-=HxHQu?JI61 zE!p#_TFN2OI%Ahk?Fq3ge8^8Awo(HVPu*EX!S@tIx2_JimVZ#ti-1rgYJ-eE2VL;X zyS7NY?_^F|=S%jMEd4J2e4YQ13Xe}MA$OF6A(x~pDse`e38`|codk1>BanaLgDD-8 z?SstCS$T7moec?Q@ohu~S~E(!bX*IVi3kp{cKnPZ#$a z^N{b3?&QOw0%(Iw_n#|`>Lc@t%gSmm1}T(9P;$k@yaNl62yBxrNGyDWY5g^lE!yv< zv0mgMv!`Pc#GU`@8_CoKm!;HG2-nis<|dxI#}~S3UOnVSZ?X+4m`h4I5|6wF>jh^lJfLCcVNPI%Z8zV(4-_j)%i6m^)*9U0%n1*gW?A+8X73R zMGYA;VB}~1KToLzKuqOn@zn+`kvUUNw))42y+`2J7FHnU$-KCdR9B^6WB_JXb=7HV zcGwjB2MvAy90>b`QXLligd&2qMYTsAJi>;mjJhW6({O`dGt}|E@suEwHR5@5hWYm| zUw5y|3ulEy?&JU=mT$4@ToF^wGYO+H?^`mLNQ&7ww2Wns;n`(hr`vz`4GoulNnaO~ zWaYu?&7b#N+DgdPZqNRhl1WoshSjMjpk$t9{Za3p*2K*2`OO|Mo}{js9kHOz95$jA zHL=aqeXW%|q~ob1%b2dQOCqW<@UP^pr4*`h(VKZ_m2s5?7?IN}iSY<;dYhrx&AeqWn6#5o8eAQ!i$D9ao`!)u)qdEhsWS;P)AseZrZR-`W+!;wSJbT$ zPs~6y?pW;#3ae=eCynp7BH0iy|6_Dyn^5D zOb_lirI*KOSD>!kZLJ9mrz|(`8Vm# zPhzid#8ylf*k!T9`I|BP7;(NjER3C#7EWnbe9EWPE8U0$;vv#zgH2tg%{RO|@+75@1nNpJaANFzT1W_)y2n*u%rpt6x05r+r2iI~l=-Qc ztt4}l-fX`q8Zm3qpgV3B!&f7+R>@gRpljRA+PgQo>tJ%AZ^?9mFwdgu0*m^N0!K&o zRz7=4hfAn=HR-W)i9$Y{TMAyn(zPf>mYHCu=n?op1j_wXV;M~k~nMPTo; zAiaKRn@tnJ7Knn}G2Itz$5J3%jpsOZl@OsVCfhfnhmhwCtZrkh0gL=;$-<_+;4Qo+ z)TOR5+1S^FeosI|*crF3sUeDEsWZ2Onf&50#QOI`iN_`Y9&!;-Hh`T+Comw2N7pC% zCGl2ch{5Gznv`0k5HZs`DwdUQ{>%j2V0LB`ub+8TY)DLI55$I4rvCuXNo#!Dq4pIQ z!-h>16X^iBg@RsSWYF2P>8I5zfbDDm*{w&P*6IoqzpZwq&8}-!;iT5;8%aeA{}n_Z zA1|_6<&4^;W5CiH+>=uRNK*Rd1)@spw3(a z8c&QIZP%{#U2%pBxu-qjU4`2v@AepLF#jI6z^>rL-%PzTOfv-!zzV-fi-wFLC4;Vc zm3O*)?gfE2`F3>v6sa{I4IIS>D;X<|-UA+nJx?+TOn;-Wk%<SVt*#W%z$HgHp@lvhb6&bP6Btf?XEfjyvvlf;D>VY{}@0gY;B6X0A|5cR2 z?0aQ4uuHqPdeJ33F`Uvl!oYYHFPdZOQ+l-I{|(>YWk(YI7AEkvV!ArGu9H--^q#M) z*zoc`rx`Pj!k$))hj@fwro;XHON$%o*Oz4lNX$s%Ge0hvD?DyMq>~*L&YD@XeOrU; zzt7RErw(*BzNdh5%}GLQh3w?D>6!s;E45`ZxMp{lC&nAhzxnxmK2w?X2G){p#xn*Q zQfs`q1QN&h_EgC4jdRp?U{ve;Uce%8IRk?AY^Cp29v~(%JFpMQbl@$yC;@ox=mkXW 
zQf_>9LbUCy06~vu$3>vhzKp;*@&IlnKyAoPxx}M`1Nlj_^h`+ZrWW_3zzj)eqlTF{tP?iA|4LXHEEN4W*b;+H1P>x?MA!JoX($_= z!kdj$4bfFpBK&jt0$+&Fx4Xpp%vfG3)(^*mFAzRwqNVSjQjPZ2+P~ycHtkY1vMQ4;)R;w8KOYh z-^L?dZf;yky*;{;|A5BApwx=qv`|r7Z=OF!?t0c#^ehZwP690=1r;y&ZH{OnAi@_m6ANMB(O0V%Q-n=#i zh~76YD5lw$(VUPzFza=8AK|X;8PP~92#cBaJLM*(_pI&w-YmGOFV*?VG-CuvS%6@) z`~L}WFbqhD-yb8j-5m=5mo@so-5cQcoFh4iGYTJG)PxGu1x00(v=ID7T|A`c7A4U# z+f7vq?v4{bH?hUR{vGoNPC^lO5U$_aBm`vkV zjPMq&bF?c1B?Ibfif|VcrAV~_HZX$8dQhypOUNnfHQUWLS9e;{cX{>O+9uGT1n-WP zJ1N^Y_P&=)N*p~AS*`N>WIfus+7N_Ye=Uv@3A-75XM4Ec8JK(DE`>Hs{6c8{!7vFQ z1yVb2ltB}(xeJLL83NSb!7Qqne~8X4W!|;Ys=h9$bi5H^n%sUr z8hlqr&#o%24Z>%7Nkx;l3Bn&-@OZS>WsVEm7TfCh`SbDsSgil~`7QMibP>Ur%zETFTs)>y!*%IXykbO&@`J)D-0irAYtAi-&tQhwl+ryXE z9ib|=r!UsP8pqfQ#h_jvS!*%T#_|GqrtTcKq>}|BG4>bD_-FVybVWG^*(DeUyRJQ( z)E!|ru)xxXeQmg`8JkUi^$I{1`~Tk30A@$n{@=z3Na7EDqwZy(!0T1kif25HTP3hgNIR`1w|@1J24v??9vD3JVTtR9_9#VSwC?f>o8HoCy& zs&OIdWG0RP4l{Qh4$Q;BK?*C;Ri;S6s_x2S)sGriNdXv7knoA`x)Np zNR4_?q!DWwyAS#FB)ZZi*WqbsjkYU6r47UM2^K0YEm6J~G)2;11#o&I8LB3H;^V@7*7$ zHnz3JqHGqlW%8Ez??F`oeubFI?0xW>3*@BBpakX!3gXVmBf(Fuq7Vka zQ9x!vCPG_OHA)G({Dd5?(QclVFDL(VDL{9fdEHwSK-b?(e~u6yO1J2;*tgYb9tkMz zr<-ja_G`>T1xUl7H=s{n{0hmHOZ(WCzwxjeWFJw5^@eK4+^q_Tnq`%73?-=GSG5cP z2qMxqzkgGcK0pOOBgN*yRUn&F*#hz^pAB6MAGh7SABwh-+#UIcp5gs!wd{7Tp>!Ue zHh%pa9p`lQN&{r#k{9Rk7V=KpLl*9?C|niN4DR`d;(a!fwz6f+SJ6RdzCx+^lz~B@ zWJyXIM`{=Pw&@aDxdP6NX{X8A6yxb}c>tmi-d*v0O^cVV+9MCNH^hKp%56_#NIHqh zgzf8}5ik(>_~{;Qj%7QOHA9>TgkzeEzkQaN#QbYU(%m%}8D#!9BP7NH{~e$k!2YLD z8`Rf8eIbnxFLa1{`<=w3>4`*s`2l0DS-8Sy=Y0gk4u#`;kWP#bI*aIGwm46Ws7-(B z`=a?emk_^S7S}TArU!19p1{e@6O-~E#|#9_%^`vivfWSZ+XGM-BfoJORsA=I@xLnU zJ5Wv@(Ur)Tr}>MDii&{4r_9!;sHpGq3s0*nnMa>VKE4_092}6WOLH0#ved2OVB1P# z)DzlD3qZ2xxS#qqbDE%-RgD(Bjj8LZO*ze!( zK5M%yX}Y-74tTL+v~#}N{Lvtz%>9{Y#JfO6jkeD4fnob&rv=SY(6GdiPVleS;7~T_*TyeuV2lK<-$6gN$BWa3myJK~{X=jQ!Rro0qO%%Qf2g1sf zU5|ScXKh+wF)w(Kh*}~J$OXcdKAe|)Hq{)nK~Vll!LkDFx~n9Ko^#B=4}+C$FIc)a zzcZ4(>|;9e*bl!WRp7%3UIxp6_}=DUB-NkPB$~5!<(p#7Qh(O_YF%NG)v%~tvG@xV 
zuL87JLgZ0SI3~+}w!jaH-fI~34wQVxRX>HSIr&59 zYf+k_NDZV%nf>GWyshR;>LD5t&eavM12aFPu&)%1>$b?s2A;ev$RwDAAL)TtHV7i$ zvboEK9AwcwIe;Ax&NrD~wnt8Qg0Rfh9y@0unZi7ihPb5PBs+7;W5Qx4Y`Va9IL5<% zQnr^`lLCr(-WwW%j<+xdmfN1UA3!;7@B6gnb7!vk9}!=>a2!Pnk;pGYOvmp&K0zb_t~zZgYoD4Doeih@JutRwd*0IGRaB-j4VvvH zggkQ>L+8t+WoX2uy1fJlhW-yZq!fcsY;g&d0DgllC$E&hG<@M8BaDvuGON&+quH+X3cwk|Io`K z3$e&5a)?pCCv&x5!uF>3s9oS;KY!%t+WXDLiT98SuyYm1U)wG$UPNQj#m&&rm;!$*Di! ze27w_T>0sW{NoYZx{)$Tz_OB^`YwJp{seFBR#)E`Hl_w#lmGWN4FSOG8gV-AJB1Sc z7PMQC7Bu-=@b^X|y6Zsd{r#QV6vHISkB~}lLp)h4D3Ydc^lw;Xl8|YeTSB3q1{9lQ z5%`_dQ|IBkF;@wq8x)ZiSs{qRYjyU=eoe1s8{0w;L{ExCto{?dI(ge`uS?i$%>k*c zEii4GV1Bn|uHZYQXJBy^rIYOdlJ6XG@OIQp*T+q4#pV$Yct>=>;D-+r7)9-`rJQ+e z53B{K`7E-&nke(t1Z0GwQhdwrY`#`2Do5K}rN!Sf>QL;TgJ!qbel})H_4_*32WDty z3ccYn-MiHso2(B8PzTq1?hTEI#u4b~6Lj=vz z{xV0|JyH`LGkg3!rLHi_;jT`bAE^=|-_Lk)UjeZ~Bh3=~gQuKny z`W$e}H$AYJc_Dqkkt+iSg>l{~Z!}?pfUqZP*qFH^bzGu3O^%YdDQrwo3ivY1lYSoO z%UwI;%V9hLBz&9fVPZu4^Z5?Gt`=deV$w0KNpt(BMN0~PkZ z@N>OMmNA%fK>wjVa^gh%XwglNrl^GNoZY^DAr#53tLiw~qtEW7c7>j)kj&{W$!}Rs zKXR88TDGesQ<&r@Cab7s=quwaWlaqmYm30SrBstGO7c5A|9df+ZR12B7ai0!@|*a! 
z@9?9Jk>5HBPE77)O02?>)4l;43<-h(8w-~&{Jtl(bqm6##ONTrX>?r}$9PgtL1+&b zpn9-2AL=o@G{9+ktUULqM;R>t>`Cru;-|%ho9cM4@$pbt7DzO6{cEP~Bm3LxDG#0{tlm6ylpUjS3s!ZS>#_kWIiD4qoH zIEej!Q+%g$;9&|)uljpMA3M4?TVn!ax61CZ0$JPv!D~UcUTZdOuis6F>%c~3?Lnvg zmw!j!ZM~9(@U~B8LSCv4-ygn_ZWmu3mOxj$3Ys0flQ#j5&cK(ce!Qb>OO~)0iXaGr zsz|+lIuDx%soz381dGjzMAmZ~9X3{MqU|~5vo5w)w44b1OJ1u-{;ijJdU%d~D2^_y zq@-Lk8>up`2Jaq>tl+KhHtQ#)V7qxY{BZjmpUq*lAr4}-8KoNd@>Y3woz9viq z{UNgSdU6tpv&v#XZhxM>MM9gp|^phdF6*C z*0NFfb_mB>M~v)&muV}Tjl+bqCJvsq8&R{spzFfhBtH|f-=zYg5>iVXp^$Bt+8z92 zhwd4!=c-!7NmgxmY^p9M3xEq9UwxO1p6|Z!gI~%Ce)^so?qWaqo!_&K%l8jVbq9j?Ib@l zDwK0xSk(;$r|mIbl|vQUYLWG0LQS)5&p7MV@yj8J z3D*MjSE~F-k<^41G(0Qvd;TTMK7Ls$HaH`m z$M0X6I_)q`sc2?4aXFdG z@`|1*s#_*(=*4U}-}t#A$5c%~(1;o^cx^TaOO5tm6WDz=vAfnols;(P+h?9@8qv31 z(S$o*EJ(c-WUYH6{+eWsB3UaGzJ{-&zmE&q( z6Jw)hN=t0u0JcWVkE%1`Or-=$Tcd;)6$6*cvxMG6+%*WA;qPlsGLs{3V8Fey2`1>* z)FE`P9=*Vc`Wa9V-lBZST*0xLIKRi`qxW@3g#ueFS|X1pR+5FG3AtLdq?c18UgUd) z_y!W(Ic^hA!xgW%pt&0{xL!elDSbp&t2(UHvMStTu4S@XG6&7q_JnZsO;iZr=}M~Y zu}=8-HB?38ZLedA63CIJ#6W($cvM-TX&XKA@w;w96?E@Gbv>VV)bLCiRqY%1n&Z@* z;OPrP3A{ra^?Dv+YAB&|aV7sF*EC^k{QwnJ*j>ypR4|TNiEEfXdlwT~j;O2KJ?uxn zMYSqi55F}|@}lLQw+;7#AvldL3)Kw4!gy$#t zyl=tikIpF($7VvCF>vxHL$+VNs9Mo>YS)ZLq*rD1-flQ?C%ZywmO_c1m(v7Cbg}NV zOW1Dc%?kDt^lfqiI?700=4DBf+FggoVytUARmodP_d(&;I_6XsdnvB zG6P#5PM`46iH^Fj%+N}(?6NHzNMZH$^hpuerz_Fr&C%LIOx)LFYy9 zUA!3&T22_dCNle1XT#BscGxyDwS6Q122u#V@(d&G`6=@3PU)y<^UNvxvF@+vHPtM$ zq2~u<8DS(odU;0ThUlK)UpnGWoDt{4+zP&A`IPMp#5bC!CASzXc@8E|Bk4&pTZ3kl z#8jT3)eFT+*Qpu#h~Y1#-Oi-R&0RgLgdty}pDd*sXwC$XZ-}T)g_E=3ZDTXCWFeRc#Q|Iq(zk=m09Rete1}n!!SF zIiA}#+ED2d-0TpWU<>&==54K7QzjYG80l+oa27{_`dy0OyBV)ry6GYqCnoI8y8+56 zx8N)TeT?O4B47fmm|L+?IY8Eo>wcJYnmSpLyR=yqfD-jaN`}R z*ofVXb2@niN9V9Dqw}ELydk~w9@ksGoI^t1?Q+f9OOkKQ@blm3Mr?Js=tA!AYGy)? 
z#Uam0Jmo4MEnWps^*t$*X6BmX?i5qgu-*|-L|s!#6+OPRE?=QxEG5p!aOw?s0LjW# zi0*m!{Lr7tmT-J%*K~Slm!N54Ln>8ESPwv$g$o1~00 zcD`CF6O#(Fa0uLJQ$=P{CU~&%9-IwX(%BY4@nE`yCIeL{M_g7a;lup)>7D$LLvX>Q zEw}p*hriFHoje>4ux*EZ5VM~YmJ_aU9XD|VopVUZ)m+zQMmT!jv7WYuG^AX8Jbouy zM^)TY7g!(jCi^;l6j#%ksQ*E!a-bf9s!>Llf2%ZgI_B`zcf^8*@--argMF0w*TNSe z_2U~+Ny5gE+?WmHRKP_nYQBhdGuT*zB+$sSHx{dtm|a{_sirsjHR199>*arEs9LoA zvY-iulk%a##!9L2(tJtY=4GxsPbE}ZgFMbW0|^%ol7$iM5Bdee5^@BF-+JvfjOZ%Q zF2v$kt{^+@*U+$J?#()U3qO4?GRzlmfwN1(3y z0k0*GuA*uOchF}oD_n0rkdn`ax9o5Bo2cD4SMiK9h&||CanSJySNfc&&+VlTr5W$H z;3-KvT>d4}xZtB(XXEzV1jE@=_P3)c6^`hdEbsGVZY6zCAGmzE${)fw*}-yueLGZ+Oi3Qn(pCM5fTZAm&^o}5hZ)MB-5 zSBq!s9NC6c?6T#yJnR|xd)oq%$Ogm9o}?Lmj(&ENW#%st3uL);;o&Z8(^kkW-4b`|ZEL7`5;mEDT$Xf4^!4~5EHroV27 zTg_8QB2^q`aEsOR9dI+2qb`cEe-73Z6UtSYnTM6$f+XMe9-meagYrT%VrNsmH1^Rr zIn>?_@Y$xUKlAoTAUim2eu<4*EhekqeI=i>QgZkF#GHE7>>Z;Pr``)T8&rUMERi`R zk^K$N@G@w^aameY^35c^3A09(mz zldA>n?Y`I_xF-!Y%P(|)8O0Y%T>IgW7MhQ*v}qq$x~Ir4;)>0DwsD3#fUN2D@LpyQUXzz;Z-ass5fWpy))gJ0x-5ENY(Rb!zx+YZ7!cboKSLF{EKWRgi4L-%jpV1>+_b zH_qbO(ggW2CTqXNipPFrd!SX`4e_ma=GXe;#l6fZBiNhQNB2VMzj{t_fNl^B=>ngR zkgU=lE$x-8_s>UQUvyg#0;fnp=uVtndwI?lC zq@8bwY?DN#lawOtHjOt=iV@n2%?T{+(YX^_mG&FNq~zm=?olSF^4gu%*Z>C(FGJ4XkYN`NjcNl0?#0j(~uCG@4E zt(kJH_#1sE>gRYHM<;j3jk5OHZ}Ai>k#hDwb`MRf28h zA)D8O;&Shn*=Fl(shCKGq((nrr)h6=3*1rRlWmi|NxS63-}8~M9p~F$*uwGLvxf3~ z4$K(FwpG&J{rsz4LV3qiaQ^_SdZ2C$mI1OUm#|f8;v6!sN;-OuZ__6C=_ceD=T&5K znMN6h=2qoST?%e19*pjnEuS0;mEnnh+jy6 z@ZbG5Ra+$YHhE^ZV+vAq(>1qK3^OTB8Ir&w{Jc*}o> zeg2=bqTTNrKh{@E#?oiY-a z`=-x5(_?o!BppK&=(2<{@7+3_4N;_WcYcYt}aajPLzgJt7B^bvm z5Uv=q5d&Y?K?lZ&Tr=*Vz7LQR(Ye6%=Rva{Uy_Au9fm~miL$SvNv9&|kX&J#c=QPa zfBmh-H7Y~*(g;;NAjq(6L>|F>QSs3TdzzB9<2ImDr|O0sC!COh0*n)U`G?}U4HUirLi>sG-X!HX{DP6p~a{X*c6FPna9O1Hkm#m zUBKA-XJkf;2`e=nPZ>qEWt?9>aXh6OTy8?xq*7vjMef_uYqz@nDMOEZY^dW>7!h0` z;q(>*k2@$y$U9|7hrv+jG6Sf_a|VIwV$r>2u2msQi=RzhdP;rz;saj0mPvmPtL!jm z1xEnS;_r)v9Z#r(H}dr1XkV~2icS?Hu@f;zK&aoEzHoPkE3N*Ocz92dr|3=ZJ0Cv; 
zY)OaP@_EF03HRI>3Nh>wet7$_WXUKzz2JRC_~`;6dihPbv58d>IBGjmBw-yS`=cDQ=U1M#utta%_<>nevC}X zg56{ORLJkc^0L3-8Y+UamRn2)4`~?k$Y3RQy*60N`y?IS2CuXzN4ihDIRs<83`p5t z+SyKRp(*?f*&-B}i_(zQfWw{yi5v{GDL9*Qt!B?q?dDaDHTX-_RU%th&ogg9k~)Z)Yfe zmh7o=pO%4d%lyq(Z@)nQtxETGY@PRw9Ulj!hQw}dkV(c(QupCs2nwR$VF%7hvqN7l zUij{hX3dZ+O6#929A9Gh5YFFp&SToAhW&}36kf?rT7P2$!oM=teRS-d`8dQ(>PnZX z`*A5L+Vs?G`y#57dnbv13X>_33Yt+znQF2^g2~C9SC2>R4YH3qH2G$^aPxdN`4)XG z!6(zuh&N+cqH<0DCxZ0>jBZ|kypQSdZ@R_z7?$`}lcat7by(sP@bOHqSJ-Z&I%EF5 zl8pHy_Fm#hC&T>ksH<>dx(D7{O|zkxeCG8&M%yG6;o*jkSqpxP+qTea443u`mbDM! zL06LEg1QSG9cd;wp32UVz2p0O2rZZ|rAqlWt2c>QKN8&BDqtWu{7+*oieF7Yw)(4G zwt)3d(4{_<+Zy`Nao$-@!BECj(~8xKlJdnmE^^G78~84-8tHcO+P8cdx1Bbde>jmh zqa%)^ITw6t9|f8>6QwKIWiKg6Og+>;AJD5qblx=Ye%X1(V)9E|Rzmr?2zyDk!zMuK zCrYY2;hm-!?46)bKOSjdxht(21rqFDX>ncp`y~fJcT*Y<`xl4DXFoQ`Q!GSz&*`*s zKBJoa2C$Qe<1a*vo7pktK6&1;K(27LZ$%PEkNwemVZ^ExMub6mYuAbuOn=qY?oTFf z2rKi>i+s>OKj#qhCR5=s$^9~pbZDx33 zuZqq{$tw32AVeM*rbkcx{n|K~l07^6oTxWTM|+CsYiIFhqlTAvbm~q>MXa>n0@8c( zxsFe7d7>;>dXwouSxtl!1V|0GKZ8Ekse^G-Lm`uYxjTNB&JsFFQz0!%>L~LsQi{re zjbWoW!s_A>XM;D~NY)MC2BwR+28Jc5FZy7EF_JH@@LkLF*cZH~JoqfAud(QCwriUw zLjvuI6{D@x*HkQcVDrI35b!pN>7@= z!6PNc(PIriv29@I@S{G!$otA!w~HB=r)65U=S+&AY+r^VrO3uJyqBZjJVDs8!yhxFmw4d8Gu-;kI!u0RMgXEo%^ zVk97Wm^Z1(uB+(h=R2Vl9$@-5Z>1MNTH3y9qtrYNg(9I#M~hsBvEP6TY=AXTJ>PWs za?2qaET)Q@uIew9`e6AgPpV(Arurxj@J`~lAW+tuphz3b+t%;A#kd;36Ijr&uKc4V zG=DzA(1q$_Y z82>w)Y-D;+mFK3s@;o@$A_>Yz;fcR|gszNZmx8K+iSaQFZfqMq_gG~^u5pYVbubrZ zG;tMC$~|BYBfc?^G+`7S`0*a#AsSl1Vz6_d%z-36C7%pXk*55B9zL#4SY^zy3SxDi z1x1Rzk$wMf+6L%~|FtwLSIhnHof)MWsL5jw6?t=6*6pO=XGx0)XjtN@i@MA!Bx`Z` zF+8fa;xI=w)z4#{Kc%KM^3T^Z9X?-l<%l8fSTH}x9_pBcV^aWu)>aXNGD2@Jq0X2c zbrsrpZ&k`Y)pub^;&RHX!n5^z$R+Ii+;eVWNhO)KtJnm^JJ8yiPBU*O<+1wUiwQ7$ z#s(NkgD0S*E9cAAZB(t|5pCP`E7i1!&Uc<9+mQKCN=qhSCD8MUOu2NCTVWKerd;QL zQ1Y?leG~32PzfF9zvOk!c%J4k=hB<}g7!P5>|bbN`^?5(^XCjs^_1jg(y)gc9&FGU z9+lvG=hZP_)g^5auxUxa@+pdGm&h-Ww6fduOso+H=7z_!F)E-Obt zC;0;6O^;@VLeW#jb=^H(*aR1)P=ytYK4L(m-hewjJ=y!D#$XyXK&Pyy9Ti|*k88x4 
zs{MWG8C|?yg;^Oh6ZyQ!#MkawTu1b`L|5;A8J?okVLX0GnAA6$S@A(3zHVm|5;%R8 z?1f#dZ66To?vSkfLs=CQZkP9tX|gw|t?fm~g!ncKL1ZZsJ;v&WJ&LHbp7MD_$8yh} z$sz|wv|#W={YT?Fu_V>HE;Cb9(y^S}P$V$m=wbFD6g^(~SZO`^Scf`dE~6N(s`iq+ z(d0062KT*}+63_-9qSX;zCk^5^xi_~`fmXjTu@FZXGkUMc94&8sxy zm9JUOcYf&ub24Ri18zrrj(PVE7aug@;nc{1V1DR@NGOLRb!e~5ju3>mMq1}R!RAOt zI5U_MZen@DNXJbqH=&{bMg9KjKhDk&3`F%`eJk4i*&EM9KD+0;LyePh<{c`PkQAy z%I$4_(%}O(U6}Z5=aUe$Q!zw;;o|vw_u|({KHD1^|21U)Wm}&CpGQAQD(fEAMzObr zUr$_}?1%qjD1d?>KZSSVw9XW5L2=no0JaWms^xTvafGq1*Xz`qVKcoIhgVbnx;K6=L}3sA9-qCoNclb_nD4slPUWpUr5ag zj)Vj?>Zuu7^O}=y9*Ku&G}dblk{O(^_7GwZ1YRYsDq^JnbseD>id zq_%i^uJ>)U*OZPJaL7rf2Xmwy2i~#RhI$lS)ei!6np#ltCP!Ry5!TLZ30$@hx&Nl3 z>SI7G&<~R7tf3QV&kjX)w*`_Tiqs>kbAE|S{u%FHOVAHnpR7Lp5_=#rYk!x(0J;zT z>J$894ac!doZw7kv+HA3iEX?x1^k0J2|uSHRG>Z}qm7a2ZL(dW$#T>g8Xe4^w>(Y_ zPvexE4xhZzbps>LMcFYwR^Q950sd3%z-((*@|XJycRpn+YoGIYFE#D;EUgH9n<4u7 zCVaOu{KJesNa_P0vtMWgRE%9##e_x>nz= z;&5?M)j*;K++^c<0Flc3l`By0wQ1lr zMJItQVdlK_=qGFu>=uS%0U{hc+?0@HRat<(^P+VCRAFSqOoc#X&r}8VcMZ)`-sZpi zYzuF?-2Y;pV?*Vk?DT(^7?Z9s1L2!u+^o=r9j!SJy8w97FE>5Yhuz`5Iw%s$ z6LD`F%LXt}S|$C5fuie7K98$;AKK6v{MYcv`rV)m(aJPB47ZhOU(NTH0LZ!4J6V|+ zO|C@x$3tlVcqr77zLCQc0CE*3NbNnMHs>7}l-Id5jvcKA?K7?OnN9zM73<6ulkp2b z1?{U92fU;+$I->-l*K@GUNcJQu`|oKCWP(n???kn3ut*eBXleOU>gk4He-cLkgB7>6#pV4m{Y!RT^b%v#5v2U9!$K=W2@_o7}1#TSksn z;|*B;I;BJ5+^S{>$J0&PG>PtKCSBzwr-rldI*35xmP6e&B2KPlG`8m)da7c`5#O14 zf3->IUuYxBJp`?K_VU8S5+77g3J@F~IDY7qvo#NdX`+i+%;Mq%OrBjP!8nOjtbMW9 z^%ZsrY+&W|Y=`P+oSr_+O!Ie%Hv~OjeZ$7~1wQ_6Csq1Z5<;2ofHW9RJk1W@fKpSj zX;cnx>81@f1vFGGi{vP_|1*9dc6+s2$sBR4bY`GKeoIDYO{@IL9&y~VVwxPD_tUd# zLw`u@rf&Cw#JnkdtYpUV{2g8*r*1gQhz3;QeZHpkhhKySm(QNl&LtgYMfYa{S7z}= zCEKCF32$?aNZ3mFpf47xss;S9>jGMt;TKxdz&A$$>YH5U5eM(7iG+UX5~%d-#@5+R zHSUs;-sUR&;(IBV^LW1LB1iGrYF9(2_5-Rd$1v3IG{rbT!9IJI)WzUa%emP-Gn%qD z6>3zT>0w$>#ixoyBPC2_o(jLn1+t;}HwaO`a$1I`vfFh4-)ZnyIu_Lp*E~z??i-i2 zSevZwV(3Px&?Aj<@frPNw(n43e&c+<_6@Y=!a)9AJZ5V6luL6rt*$3wMf%-|K@YJ} zUzCclalXTse&sFG%@a&C_FN&XqUuZ<=u6URljk$om*L+(0xU-v5?}wMHF7~$I?$!; 
z?R8$~JAf!2_<8vt;|6=qYv$L7VF~9pq8Vt@HOT9J?T@kYja!G0hFH|*;Atppwns{| zHMV^c?>N|E65H5v@D71Hvu7H}G-cdW3&|L(57}9y*%y3xZldJj3yjp~_ip@136~?8 za)jRg*AwLb%5nr1?9Cea`(d{6E+(MS9hUG(J_>Wk+;LZzdk2syb47zR&z6prcB@rp z;FU*C2pZEcMUM|>JajLs>VQ_Ch8x#d3X1g2rh4I8{!iGPJ&>!AumjR znf9P#6YIt7KLkC5tyo8<&npNmvR)@u=Al2H7f~kE{MKHUG#;#;JxUJ~s^01bXze{o z1Wc)z0mY?vuiWg|`FAGW>%C<<aPdX^a*p#OmT zB7wQ-C#op4I??DXZ2Mhp)~MqPOsen3HQ(qO%R{;(+cZuK+vt}}$G0ejT4yHT*}B}H zA+}|enb&#kNma%Q+bauUzTR-Kwh)~#p4bwpWHpV(25Z%GjGr$z@ptp5~c;6XMwLzG21-b??w}X z&Z${kbcMoKRyw>$EA}Rbm)b03j$j%y(eb5-Jny|M^46c~FXkIDyP`LiajEF5e~vx8 zzujP`-0Xx2|J-6uwg!drh*TM?%E+1Ye!M^6`XD&U!sPR%897r%F>|=>*B|!VH4yd* zbED*W3E)|&v*VEMkk{|o{ybEGLR>W0M-k{b>cXA(h4=SqsX%>Y=Z`B?xk8UG z5%#FM)#N*}LTd_$`iGxJH5cRe_(fmMZ8b!JQecKYVdJ;7(Jk%af+FVq&OhS7YOhia zuL6duXw$A-LAP0nSKq$BKqN4i5fB0&ca)dGOcA4D6fT|%{oyYqk7+QqPjeJfeDBjy z>>_g1q2&xd$M70|4x9nXBJ^h$sncOqahq3DkTPCJS{7`$eY5cDBd2c4a)D0jQ_@#{vLey+y;nFm!c(-+xaujkk# z621IFX3BVCkNAk9+-1GJz0r1m^xkJWy{{}mYP2Rh9oUT_sfuHgG)M@t#X1IlmHN+s z-;=X7zvg+3+`W|LKfgg3!jMm0Ot}_4AiwjS2mhMz?1S5FxHk8Nz#zuG$wgebGwVip z?z5~gwtE(O0nxIlnaPPI(q5_UHB8H{efP9ejSWSfqh2WpnURNWSbfvI|Ew(}uX^xH z5AGROloI}EWoW3QdxqM}Y0?E^*b`JK{y>C%doqr!orfcrcHzrtd{=ayB_v9z6yE+6 z;2+46Blb%-pqR4~yAM@|s!!*Z`H5gQr%zg+?1(5!n7)hiZrEiV@_S}E{cvw1zWZ=b z-UG>~He7k}Vo&j$5j9XKl){1eG1(fIG2L8eYfJ3j(ac@$#B*-+?+cqB9v;3P`@|jn zI`I_maVo9@!{!}S$?My<&&DL(2i;8Jx;#`uRP(iE#@^zdk`Gy*+Aq0s^v5e@xS}mm zw!qZ5<_qQ^PX_9iPHD4d8eZm7XSpOEe26S=cX1%!u_B6cRrx9WNF(2K+ipRPLBDK% zL2EPpKK*uvea<}Tu6ed4D>vm9LdQdp0pIty(OZx}{B_HALhEHr+S+$t&B-X)mAdMD zFB$}AB*o`G@+6+>YG30~Zp~k8zkfL2T>@Ii@L;;9=5uPJ!LU_gnJdeKLmmR}+ejt` z5-k1KN9l0X+=o?5#SKCXDCMC1lbwsZJu3p0&dTg8ytO7~sc-H4C{aq>&<*4EH&XIa zEmcD|%}z$|2eJ; z2952czuyzzb;oIEeM(SNdbNxP=lX^HH*%lIj$!5{@l{Uogc%IJmSUp#o_&3&4Ahxo z{o>msMvPKqpnsH@*)x6gP=Qoa;dMs#ZM{d@rz8pm%^?)SZCjxkfG`+WLlF+`LWyP0 ztJuAwg__pI?iRE_fcRJn`67qI?B3lPqPj7OuCQum)mJ1;fbg~uPuQU{pY;==m&u!4 zLLTgsgpk|Ks+C8fw`(PUS2spwW;1U8iF@&0z{>0<;o#9rLX5Lr3YIgMgX41#B1NBu z>th|Khqpu9IxxVlB{N0fmt*^6C)drd+w43 
zFAekpFcY&*Mh(nNdQ)=K@5gl&sMQDHzTAXd%@rBDkZibkWLKNPKkg~!E~|`%t}Vp7 zA>@3U)e^MnB{R`9CY&+P=$z<0!C*2zK3S|fK9A42qPnAM_-S$PAs@C5;_|HlW8!uAh;H5KGsc`~ihE;x&6t1&8>dIe`X&m5)Bx zf`XTOVqAA49sVko{RTM_i0V|ekg?T!mHFj}dS;;}<5O{eS^mZ%h0(}{j4sWUUXN5r zel59x{7YFOP0_nzHMORqKQFZH1#{z*ys4Nmy$a*1-iCnAAPzV=Ne#aAno=#a(CQR(* zhRHh#TrGwqgo7r5WHR<4-CZAMssqsjK}pH8Et65O3Ek#K(j)L{%yJ>LjM&deUXjSZ zJdUuvdt~}po=moj=?eYv{JnQ47d5Pj3}|iyY&BmFNm%AZQx^*@h$f5p!f(o#VysH$k*;+sMQf3h-TDa< zKljKXztZ$$x_B*9rFhAEPel2w^n21gWp2RG-)F0@oIyU6S@kbZuch}#_=#JvgpNN><6~VF_DZg*;y^E?=`1R9 z7uF*5b(Lo<)t9OCDI9(nJ5X|z(V$$K95=Y`EyI@-VZn}${}9{jxfWas_+81MVJ>Fa z2GPA`{!}~_^DFK*$P;d`^UVesIctn33ywklaISOM z1pbbkJ!bJv-;^F&8!}sx+$bIEU9WL)ZV;5q=kjQsHH>wR zLlK}%o6G3hLV}-8I7sG+#S;y>L-(+F(0Ebnu&eYqYi9ybN=PYGbP$_JHhbueA@VSZ zFweCeLAS4!ws(o!WZfTlO{lAx+>+2Ic?PN!z50c5IZ60tw=PG39kK-Wdbi{pxG|3z zcy#m9byLdltsj?iM2)yI)3~B}Kdw#@H78pJ9Zheem$9?p5yz$a8HJ9NwpFoEQnFJg z%EHkdgOT%6jBhp%lBHVwSpvDU=eL+!ifFSa3U;mZY!Zm2K5cLVI}PcsnLe=70#@c9 zA(Avn(HS$lpJsPM5{?UhnP};+OqA!F<77eof*?p+DE}*aVK_~Azb99c$p$1a2Ei0k zbFw9Sp9@0x0j|!E@r+O@^Su~+X&uE2|IL%wM$IXRxOq0uCWDoqD2zW6sW=ZbMA(6( z!QHjys^3Y2ST2NEe1~`2Ogw9dz;wcBha_C3vD_4rzoCMSR7N%8ORL_Lk*B``y^)|T z(mE#9w-uNGCm3%p9*^`I62_ZnuQ7R1MpaV>eZ4}zSI4|riBGeWDAH})H@G4uK0=?T z_~ZF;vp`DgDZ063zsNJg^+>mLdWRpiKd}Xd@hz&Y9QnD<+d|%rd|!gy@$ymRI`h8# zLXwj5wBys*z14w>YRU|%p9NN_`zYi6Wm8@(_y&h3s7-bXx(oh+A9WVuoslp))EcmU z)k%%icKG5MMmw^jGs5WxkhaSSR9!*H=og9RPKMb;DS}!ju@v6{KsQHa){{+=0;Y&P z-Y6F?+g(pUu$)Wl%|`m>%bH6}ap}>hTZ^@h2CtBC1_)mm(sVdlJ$g;$l}$xOh@zzD zia@=bUj2N?p*SM7ck~@!VBg)5X8hb}d>R9Djif?l$$AY&c7(tm2AXG{NBg(QU(J8s z&zl=!FPseM_Zw~JKTVqVda|APjOnv>Th~Y+lDDXrV8}*zZZ+qsKtSmaK2N2L`(M;R zQ_4nL)Q?!-VB?v&L9qitVfU5qUn<;DsYfa;$2A}2xb1?r%U7g8dMD|DqJ0**%`@ZcgSGDQrT@h9I-$;G=Yv1n_9iu0)htY*+d=Z~qQ4 z-ye&b=Ax;`nHZvmWVygPVxIlEp6{OqgY1R3-kUCPc_Kv*o8NMF+$QTs?{AwL0eTzI zjswEyr=(;pFCgi@Z1zWe#6#7a_4t;a(a9vl+~{$RD83eRa4XZ6;FxAKjMNLH7xPVM ztCFhE7izH`kLyA;J-^^s*!56FJv8_{*KuKFw2dsrPQ9+{ZroG&djF`x`UBo7qThQT 
z_gP3!VH0h?<9bZLO9?lE%qS|7aTMhuY~L0!_aeL*4btBdn9}a@tI`d42K9%swbJL@ zTnHeC%Kegjn~^n(TtG1oH%$GDr{lxVw2@j*yHVpKezJTs#x~#O3toteMlU?bz#u;n zxLdNZR=Z5#=Q#f3Y1n=Iy z+mc)DO?w~1?JM_Nr5^g~lz!n3#y@SPuN15iQ0%To48#^v%kCEhWqRbH8GV?%Y1do1 zIOt_|I9!##yqKTn)R5%<$5q?ya9{>)dnvw_qU%9DPYIZZs=;5ml~XY9yq{D5qfsyA zD6A2AiB5X5zZ3uCImKu(D5;W)9@s6oZYF|uyL;wnGjqiT7f%d9V-<)?5(MC{cQx>& z!kPSP=B3=fl>uFP)Ft+|&-JE(gyy0U74(i`FL!Lihj+IqJ>Mrl{&|EBVTPnX##Rr8xqqxLYvptmRDLPbA~vh}w|>E`0neVJ^^4UxZC zO5SuQJ&;HSb>aWloz#_s!m%O3tu;LK$~3ZcaP8@@ax;7(kjqYMDo4G^Mys$PeB~S1 zzfXKI+z9?}zW-60Z)J(zC3z;mg>P)XKkVA!sK`1Hr^1m$y%w z06yA@_6P9+}c_k+;&zP;yf8_9J4{(Y=~;6KYzN?He7jZwPH`+u<7Jb106&@!YGDT6gfO7hqLijIz!txE^dxU@M5MbLb8&gb^$r+-pccx z5wCZz^=|djzBhHX_5D80K1HlZ^`dnm!cNTjRC-1Z#8q!pS3;*V#n&T4pN=Z^cD|wIT&E4BQdl;r&Kso+M~^|o9YaSrdWKqdOgGE+l54+(6XRy~=Hi$yH%s-HX=p1~wYL&91MDYo*KX2--8o+mHY1m5 zs!rSg)zXJtfueS{fhf1T@~|h4Ge9uQqNUaLFjt-}DAV(b&H3Dp{}0m3NNmo(%B((s&nPsZ&byr}AN-H#>EPj{8o2lbC zL^>@4ZC2JQweme6MdTm|Jjk#oGh`nM2?vOlqF=3r~v_x1_nCqq94f_9Z3z$jx! z7cn!X7@7+16Q%=+xth7IoGBoYfZ;!Uru0;5Ik3C|Dvl4#3ADhcpM?(P$G3QC=13ga zhjOOE{q0=-&u@bl&Hv7V#UJr*4zvi77ZPrN|Nb5HJAymh*qYfq3q7zZ2Zl5Gw>_Q~ z5ox|~+%}(Wxzl_lpU3HZ*(Uu}XZ{0ezLqmQQ;m^Ec&o7@jzra^=43iEN-}&zrY(%m zke9zfutYz)eYp^JZ=@hUdGts8!Y=<46my>JCMfEg{P7@HKsweZcQ1LER^OeYqyIzb15z?{m3K3iXi_ zvgktfF2NliTZT7=iLt&wd}N-Tf(t!U5&+OF<9T#K4xou1_;gY29Gi2}nMIG-ja{iv z?iVnJ-RwbH^(`p)$>?!a1NNvG2{blT8R#zQdpOQ1LdfnxZXZ{yJ=+nlKEzFZ@0 zSc?O}%h%*<^6xUid45cQIwbRI3$51<>~t@ge&+d7lkAh}@@T2=OzBFIQGNf3A(wiAV>H}=uN;1CD%*9er# zc4l>R#N5e16g8z_)+Qb&ajp5hsf87zF0;^g*JaN&-R!P zKNMaQg8_kDxggB|K!*Vck$F2VZHt!&ShcUffCr(H+2LP>QbV9jez(orgSwdwPniBOIUi|ye+iNnZ)t!l%njOQaBBfa;uk`TdJX#(eJ+AK>E2ZO;P|G)uiHD|R6HHxB|IU)sreKQj27tFvo zlrFR}?m)8^}Z~CVbFrK+d1cUmscZ$6Jhr3m>Oye_&hqQmj~ir%;io&%9rnO|>95l@44sUV>CJw?JCGmOYrG_t#>{@4-SkNJA@Dp<)csZT$vDzgZI|I!5pDry%gMl6*_*KREgiOuL%u7)Etms}O zZ7>}qg;rOI+;W(WVGH}Wm96J(xaCG^eF4guQhn31nkZM5^VdRv^rZaJS7tn$MABC< z+nb+N=08!2AjOhoL?RFdJRiY5z1m`2*}1@YD|@o)b!@DOuaV-`t)rxLZUwErDKmc? 
z!@RMV$d@Vx#3^GSNwR_6iESXMH9^RedXT1dsl2K-UB+vIc)s&-cMpnd7;#?1CKP>y z?2cA*4CCY2a-MGk8S-V9%N3sY#Lu(`L3VjpG;JvruyC)G2Goiuu>2>I+Ec3 zrQvdXyH=$=|Vqx7>^=b51EvbTBx5OCX!M^OuZK!KP zi7rosF0mc%)%t(CWvfE+y%SeoL~=#fcwpbOCe5jDN`ypFa@9c8d|W-?t(*Rpmu{;; zoC;in6s;rl$7GDs9$Trdp`>_J(0M@uZmh}#ch&OOs>i%*5_;zBL+@n&geB_x4rH1` zmSvci#cSL>F%5Baja&q-SO>|GGBWfx&3ubrq`YZv0D(U=zbn01b4RNeS>SB{4HNJh&|zeko5BR`|L| zhly5AfmiVQNONzXzslJfBl2ux$hq}k4&k(*wS1&%*oQRK3xn3oerc518xKRZjB7h` z^KdZ**P#$pLva;Is96aD##+YL!D>0$LD)Y}Fp;SuD_L;!WC2ySZ60T;L}6~#7CUe1 zSE<;CxOstEM7rmFIyyfDPe*g_ma{hI1e7y;XSQhF+j;Sf?=*ab2nk=!pxG0xd&y7@ zUl&5jL6YrYXZ79L^J+T_DrPz2v85F}z$qYHwvJ%MO3!mhdBdrQP?aNL6Xz2YH_zN@ z9PC5=xd+TkONq>TXq;ye+?I1c;s@KLUTt6#g?czl?f3(?P|CqD7xq!J%A3|Wp$)!G zdb%x_6`~mk!_A_ZI2eG#xhm&eX-Hn2D7ebl>X^UC4xy&!vtO7+Yn3z~*EJ@Y(PZOR zfNhxe^tx^w;O^xN{-{p2AE9smba+HMC%OGWx2E;goAcZ+TWZt2uKOcvUYF>CrZt1! zzW?#RveyhQAl3J3zYq3Kk)9zDg=_}GuPrab?+acL5HQVPN%}T-*-T*{^P`Hl4Fv-g zk)5-QmIn$sS1!)CsvJ&s^o@|?yaytP*J*<-49~Dw&+He(g<$FRN+?ec3DPo~D5^I6~N48v)9X@(De znw2A{F3T60oFpGy&m5m%4R5s@V88+p#ctzQkr%;R4QD^adqtaVsEAJ&e8XOk34@MOlw15H5hhOm@&W>U!<3xR=FxgsimSvp@DQ>$&dU zdx0B|5Fgss0SeiR9IGLTC+WLzzX?MxvaVQ;Od=XP6TBcdtS_rP2BKQWTkGDdCTEom zUb_vo3&%&$k0CuC#NF}gYkl6R3Al_Scg!A6zlcI{1tQ-$n=|M3H1hGB-se zS9&rn%?T{FkOZ~FGJg_UIdR=SaQ1}t8!C@GBZw3kVoNGH?=$y1P-7W((c63Z+P7ZI zkqX^i!A?2%X!bq}ns9hK7XNwVpFluq1gXz?#+1TA1tFT6>)NUq(B_d-xsjQA`S*yY zr(xjN>Wv<|pF;KH$CCCu+Gudy?Wj`QtS_E#AX!?? 
zu;1?{z2;)dQc^*Ze;L2O3$)DM03!?#-m7i@s=d$D^GnP=S4eOe^Y~%=I^x*Q#+>)s zoJ%7-_T0Nd2+R2mm4s9i>d;m2x%Q@g$$VmxJU(wF-$^TO_tWVMJ#$OtvEA=kUco0n zU>BPgXfKzS?f0N!-v;<7w5buT?Q3n&Q{zcIuZl33Z>)q$VFzx37~v{pkHd2@z3V-} z*6ZbH6YIBhI7KAzkltP8YqE6iY?Bt}3%Sdcs*Y>fiN!aSgxly(2$@lc8d&=V7hheu zHD6v|`79Z3SU~bZZ5mqJkz87#S=R?HbzO2z*YJ$hz8&vy7a_d!4_OWNf}Ix)Z5W6> z`>kR$k&y{&gYzHGZW#)mN^UzFrDP%3z2AkO;<8qWZ!3NfXuVj}G?ouu{p!qmOu)3R z7xQ>MJXmO6a)CLh%*JF&l}iL+5QZXv`2MXlMda6`Jfqo+|YKiFQ#O$JJ%#steGh%fX~qX&p?MRF-4pB z@VGv~&FsO>vUvEu(CJCrb;-epJNj;$>93K_Pul7nw^^I|vo{;kcB3_mDK4Ue4!OoI z_m)WSjazk;Ffa=JtiXf~(4t^i4_WUlm?^s|9H|ui&Hn z77E4JnyvOxP=^PwdzlFhXM7*|gIB6!ku- zNHHO1yCdMbcZ<~f^f8!@X(Wje7OJqM;K|?Ww(~TR1E|J7KkP|4)$a~YQyS|nV`X`P zuLXVznQJRQv0|~$#Y7P&EIQXqh5yZz-)tniaDA>X=J%7^E+z2%^RrkX><6B4WR6T_ z{(|v^7B@%y`bnwoY8Vu%Vx!U!4%Z$z`u&X&5xx_6G+ki7TS~QGZ7V&6d~XDUiXx)i zKi-qwd~4)>R|cO8Kny$N(6W)=qh2;9cOm1GaN?gcZcbuP%C4?*;%5l#8W`CH@l^1j z2$wl~TlLWNd?kX76FkYuVYQQjjnr6)&9Bd`{$$kAx8v9|vu!^JUSB%SG&+d=ggm)= z3(RIIf^3>Y*!#iK{J*c?zTL=^SpAN7PPT-BFB4n!<)d!6;F&Y94DtBgJXjK^ zgxbGA`%i{&Os7B@-IM{H|NPS!1^5r zOFW~Ey6zy1`Z%f#2t&x9oXQyL~izv}vC#?|NI`@9-_0QjpEC`vq9f8Y$D-GzW4 zKX*8l^83K&wV?$(QN~?tBj=Tn^aT=uzu}vY?N}~Wg?LjyU^m+oczrd0t$ua`1AjjZ zt2gbRk^e&*6T{~FVGH*dej`-<&#Am4yj2k#jO+)}f3aXY&)CgLi*46eRY!viI5q(= z?N>FjZ*n2wP~&PX2Ap3XCh}=r?Z`C5e|}e1<=yP6OlcQ)Fc(KTQdNyOgl0JVIeDAi z*T%OtjP5Fh6S<&-_#H{he}iJk@cLZS^RS*{eDefz+?zuG-rCcUP)Vh@KBh*aC8{&nqT z83g$$`i_52WPpuOW#-l){d1(PZ#QP1*nh#8e;q`tx#4GnKC;(+#?_Ka2>0lI>X`o@ zC>SQR9z{A9m>4?mbqf53N0K{AZQ-Ztu|Z4_EKx*=WNsvm_}`m?!H|Qlr+YcSpnR7S zMS(}j=D2ws-z>n%L10NNa=VHek^w*6Gx41JiKr(4V*bQ2wTyc_CQh>d4U1u$_SVY-x!)&C z6QsP{8~6`1G3ev|-*tl&MowjrwZw zX|-1QVjLB>(BI72L?w@TfNVJ&7AgD(G6*)=nq;&KDPf!joI|{X=|3cs|A!N=F9)v| zK&wdnFMOqND#f&ayz?K_$ziv8?9JA|LNu z0Up#ilWwTW_5otl0aMJ(7~iIh8vX2*S6Q z{vOCS7L~t|Di!<}9J7(oNS^cW?=6bWd~ihKwx7egX1q}*i;I$4DYDOV-~3w{4u26| z)-{`Z%cuE|Dv6*Q;gKaC*=6Z-6`y8mjEpe&7jzrS3DC~J&%=Z4!Me|L3&4;Oafp;X zdlP|w5$?!6tHX0-lC_%O>}6HuYv7T;&>fCjU}#d^m4QeQzE>71mG!TK>8SuZQ0fqo 
zmG1vVr`3`lIU;Sb{{ve-D@;R&7f|D`U|ZSZYqs+zFw=J&e`*_~by$(gZo=lLfchhL z+pZ0$Z=`1kbK9j#Gu_W%Qg5XjosxxA$J#^n*=aI9F zPFwzvhvxP9WSi`J*B`m(ird2OkfVm{v#je_DpKxqz_0#omn+OO*KEJE?>Y=l7 znp0xYSxeL|_wBT?jh1w(I;AyX{+HrL7Qj=M@Lo-|7#be-JvexkkzG^s=;_m^(sFWQ z>gtI-Ymkmh{^CgI*RPkBpy9?WiU1u%+3(IrH4VK2=Py^eEp2R6OLSFwdV0XL^MJ#( z!H529H(F8Vy~mSWcSTs%v~O)z4*cMoMHY5^`LevJpZiut!_fG6U|Jf1!&Uxd!f%Z} zj6!2j)OfX3)s__GWulVTf8S|$PJsXHN4%w--D+bR{-a_NbxqCi{pmz{O|ZiRN?_-8 zGV8+IQ2&XgZ)h0nQSKu-nWrD^GI~ye?rfeA5E4S052&*|*{oL+wHf5&>M_ht)!oII zo4AJr{C2QwaX6={lmDV@GArnWM(vcgre<(V4Eo)B_ZFwBO?W%@m%A6&)`|{Xf{eTc zTU%T2Lm=X=t~>)7e8Jxx9Oy(s*%*>UO<5S3m?U4md{-@`7F6-QnIVeW~np&T{zYZ+Dize*N0+Bv$!T>Z;U_d)^b%$jB%c zuF4malo}O&F3*-ff-15s!I7;XEsY+I9d?+zasKD|d6;3#G?fL786@?aJ3C21J|>a~ zwjI|rUG+!P`tB^zT|zQ4GRFL!4xH>*G@BTala@1*+czF@l7fL4wh^!H+@ z$!v`ZE4epiW%0Wm#0n)o`+gLUZ<%e9Jy?95@9osQzLIPh)Yj%RxGy*O;J5vm=Rp&3 z78@IDaxjj3gX>&kZrVx(dHEGU{y4=37^e5#;b&_F>M>9PwqdV`XBM8!CMo+;y)|ZZ z`@h-!Wp;{YA(9#?teGvmKn7u3i#N_BZlB!!Jpz962FbqaGWNf?zo(+YvWA+KRhhw@ z4l%6BujVir$a3F$#;Qy{O-dT&AgBtCZ8}Yb)O1`%Mx=np@d5}n+)y(eBO^ASvGP@N zlJaes$e53c5U+vw#|I;@f~_Y4q`$p}h#gVPw}OI#CIb)-f=J;f?zW@ws|zF|dzQYz zz%JXPyGK>tFnTL&)9RpS94@)8WD!Ow`~LNkWE&?zn1Fd*=Dny})T_1fKm#*@_Kk@& z1%@RI>Wu!|@0*#X(jP%|zgmtW<|FxlUtC-)-Mp2#*;X5YTs)r-k3S#Y&z-;8ap8IOBcxGbE$!CbTiJEPsaI`hQoH?2ME3fQ*#Ewfn zOV91Ct?;5ucUhf#+U1W`pFhtAM)Le1p82J)B&pRiN{RD~*ypu<^zXIOp4c+?ybCfg zMVxN&IQjpeBn*dDRYgUdB0F0@Ihzf$mW=)iU^genWSlPXO@KeaKavhPhDH3SMo?zCWm8WH}dW;$P-AXbY5eZ-uUm@{M9>`4)T>%Vdea(EP*%uI z3!lJ#-B}5!0_2Ae)q9HHoXfkP+wz0cIgJ3#FY(F@BK=|Nk4l%7Tjpf)m|H#kvJ~ht zMu-HK&7WC>a0wcSg}jLr1<&$-(+(TR=}|d5*!mi!e9qa(G4dOcZsS&oNOCj7u^}5y zORJrwS%B3q#)3X%#|e)AzCSVw3iggy+8Y3J+~JiP_CCwP)iK({H{z*M0=_*2835qW zQ9S>^XFRoL%XuWV%FppIFdoAbZy+vUT1X6*-niE!L^S?I^;m6Npcp_riJb-y%8G7BgtGd|XoYtM@ z*|Et>oJO_BPg?9P{}o)ZTA9!=MI=8FFhp70H}-k-x!5l)lU2!Sqju+!nHGcd zpX(RY);f))_>xmlNRU!6=H1!!FTN&w1RrFDxIW`X2Ah@-4Xf8+kKv*c~1ony{nWXszt-Mwgc}WZWovNxm=Z6pNMH$n@Wao76r}INIj0yLmyQgn9Pc 
zw?wU3tbKSAqXf4_tgtpnS57rptNP5zCM**RiRHCBC<1}vtRs>CJNOc(VP`90PWEb7 zRZ0d$1PGbYPa5aHtTc)&u6QGDQXGQ#5zETUMa>F!pfM13-eTcEOQ{w$HR(~7=(orP zHXKu1s>*?JeU@)|y~R7Fq@sv$u6y7YSdSip^$kWO7=ZNR0M7t@2gd)e)pn8fTm=y& za!}j_l$XH^1&?|c)+g-d=EnQs#&Z1xsTItA1IC(_McLC!owfq~4Z84cYOS+b`z(nY zBJTZC69^SBOVR(<6FQxJ0|QBlf||v`;r!#AoW&(2BF@g(rF80}{B*^Ok3fu91H2RJ zmY_`O_t+}UJ7-IY@!HNAR@{Udog}NKK|4NNDw3&bMn|u#uFtsnl<@^9nHz{fr0ChB zZ`VQy#c9M+G=bn)cjrMNScHW1h_NGVblKkV=Y#tocStR_0fd}X8x0iWN@Pv@XydAr zJP&h|T`gKdpr*B;R5m%P%7>0V0p+0?jcUC4vpl^bDeV{KTOlb5GiY6sx*Mc!eKOr8 zY@f3&O@KcbeRnhY#|f{4SALdFc1`_liN}PC)_KtnFsY3W@?}zYYs+G>%iWLTsLIk|iu?j80<3thMt3pTd=ln{dgd16JwdIX z#d@hI9Nd)z@Dm28NybS(RSC@UuT39~gAsfq*Bw2?@(7l)Teeah} zsE~5MPKx@kXu~hPG&3_Z7sx|yTE#1h@YUh zEiNk~1e;+_PhDAmsoz;@AW?rSgUa|L&m$=_^}saEJn5VV2eN9DV_}X=fFuqG^wn}j zS9s2gn#)C>4UJ8V=W0*`d|O*@DJS3rEhD)+?~dB$!AtXNui3{Yr88z@vm44!K5|^8 z7T(TF!x8UQR;J|Ekr|D^upDWqA;Y-(E zWrf)5T5i;zv$U1*ih7x{X<*l=jkwGYRM!tW9ep!*l9HJA4$}NTIIbDbUH0GaQ6Y8= zX?>?4xIbobDLsTh1z2OWvMN7V#@fcrEi}-opqY-oOqt?b&vwVC=H8>&=?*ChQhN@mpZ9wQFj`E-wgNwzGn`JUQj zc4;9|!96Q!2ukPT)kU@AB-7974+uz#0SnI@@JnkO?a??n>>l2~|NKVA0}0vQ-hO)m zCt{!$lE06lwTy57tF*8r@ycn19-f8zsFnZv`-1(l zY1j8I{gDxQ#sZ?0kz5cvJK5LU*p8}Kwc1Zdreu5AQBlb=UU;Sl&(z}GCFY~*lK<$k z-M-U;gpQsb#KvfxapNWZ62#a7Ex7W>@;Ze=AhrJDm!z1ze37qfm|f+Et)cH&|gW9pZd*=(vv%vTrdc{nhXL}@f$ zzZSlFffq)+$^c^L_Q@z<&e3gj-^S&# zQ@ha)t3~wMd19uBxkZBcFj+znFQ|6eos$v-mZ*}b=$*~2tsARu2jjObDu2Yu!BPAc z&3SA>Bd__%1M2!0P>tb}k{a8K*FV2EpixI5Jr)!ol`y3%u@@G*qF7(HE-|#ig1xIQ z!GIY*R?%-FP*eXwe);^mdp!Ay=iTmI(PcIzHI=B+=_V;hPcd;ZN5Y!RNu#C}=Dl9S z%R5PBmoEz&5n*{x1b>q1JhzY0H_Xj3xPlSK=%}rAFQAXl+J3KIy9ITOw8VSXtfV=w z8`On32CDfazw_40?>VkW?Be!8|*ud!Mkkd1yApG zlf8e3l~|XcaPCF66jmTy1Sm}ORoMjPSh5qfpf#9D8;$~%Z%#tGR z910|dL5inVYFrZ~{dtwG7$1~_t1Y^5Ud%SB+td%c6~cYKqy+! 
zTXFI5zWg#6R_|M)O2E8T!kf+UuA5c80P^{Eg{7EfrJlUds45S74#-&q9n z4@`L@Z>ZNeBv_xfj!kgqve6TCqiv&C_au2e=FwMq%pO!GBjjBy=4J9!q5JN_r&UH{ zZ|jSR{;B2Ukjn$>F%&NEGFwo;!(@CF{k$9Jj9>`-8*bd=t_FcEh_M1Kwzn>A=3Lk; zQVT}%{2u1Sj5XMd8~jBKi&}0iXRI3S;y6S2j@T3%3Eama4p-(xiZtR9&bxZP|80wd(F=^OBTvMcVV0KV;s!mp`%ZnL7puET z{b7CT_j_!m&Pi-JvNmdA?gP6wA^+ zv7fQ9m({GnfsyoFuH8A=H4ji=)+5mf`Mx1+O)S+t?zOn!>Zs$`4NwJc8LS?9}0HP3-#G>rwSs6ATsSc|9|6 z6 z!A=>L+M`cVQ?eYKB{JVhITLB(lXDU>R};J|BkG-B)rzhLmA+Djk39OOfXKPIDI-Rq zgS_<40?II@mDVv}E7*)_`ASD#hq$OU2NW3gE?S%2@#kQIw+`6BcCYcK@n|jSu~|N=>U-$VZMW&#F2vQ2knl&KCD9aiZLFH`oFZ?A}CK2LQ#!~+@A_~kRsi7()r|>=M2__DnE-)zXs7m1xaNy3Ic-h z=R)a!kdhnA+D`6i=)*EV) zYsyv9UP*0d|7rK{%wv-{MGNaeZYjoAh=dl*L=kl2W|4*S791icyQO}2iFCGaVP44$ zakKfuusoh+2@)Hy7j<~SWPV_22*hcncG-%Eh};Bg$}FSaEP0JzXbTXe)5T)9A&qu6 zZ!@Chkw)sbH+de*9`zW$J2)^;`3JXLq)c8RZ=|Sta(V!_l#brgE|u zd$sMI)RrigY*YgKa#Q6GCA{&W>jzG-x?%=cH<-7Y5}JX8T2PmBMqU0vT{GYK@4DRmk?yd>$gS)$17~Gu@+!-vmL+}70xwH3I_g0ApRv<@|DmBzm{O!FeiAuUHp2Y5c!g`9*4amtzH>ATl_}LBdLl51AWiQQBJT zpSBpXXe)i>S5m1UL{?6O%*uk*YQt~r%xZ4qg68DJv!S{bxllmN%cn3OVrGy1jeZk> z#ZCn)eTv>SA1l)+`fVa?P~h1?+RKP34d^z4YNeHvM{M!@Oj*G0hL_%0cm7`5O5|eW zAULoKgG0X*`3|eRbrzF|j6AYhL2+o0{2QXOKyJA1%1R$4lDzzUx$%;uxABS!+W#3V zp7JwXA@o=%#LW@!a$;m?s67`WOib*gp?j8WQ+2}G`7OLXiN;v$#Om*#?gKm~uhZGl zKj+I_tK)FbzH6MTkduhoxXJjz;4Ffpreki!9e(#pL`(5ct6da6U)@SwzwYMiy# zPS+Jfvu|J-imQBXY3Yd0tb}82yy9|Fp+q(^Y8MFwlQFO$=JPj<=J7eTSd@RBTiB~O zmo{7#CIXukl4dTv*2yXq2Nu^p`2ZkNSLWnutZFc z;WX5}UT~r%*)u0iR@LG|qA0`L8(~hG#lnYb7-<>ksJMjC>^y%gO3Naj`{l|3tek?( zcotH5MsOKb)+x|tuWRCWqBdRmIRykFHbFb`O!fF5^2Oe0X!cD^lJR@G%3iz?e{ljK z)v#e7p!1woXQC_)gYa=Jv7feXa8@uOXDYm5mwmH{Q;c$r6~)f>&eB>TLUWO~liwWp zIk|EmNxDiKB0-iXGjKVdA?;8-LPHag_gvA}j{$5`FXVU0b;l zw@4v``CB8QZ>auF`arw{e2df=5#uX1^M+=O4~42<`h75^c6I76IB#ym1nX4@79ZGt z`sXSW(ji_O!aYMSe}h>|-b=CxAg$g*;1SsLCsd2`F2o?1WsS9^MBB{QMA>en8f-c` zgIJX3;kniN)~RKpwFw-q%^`I{;cSXnIUl1o@Vldx7t1}2pXG8Mp1#6Mt4^Lln=cxE zz}(k}D7cx#BrYE(%DD`_5(9*Z8&LVEJs(UGAB6ayk%&ifIB!2w!Z8%_^6)%%7R=T7 
zXAzDU8c;)vQ>+MvvUilRGo{Q509$@55~o+oiG!}aMv>XSB6*Dbu7YB7!9(2ni6lp&vy z!@;onp2!6G*GI)o@{i%^{sIHU*#2fw-<%06^YG&x20J;i^GPH`Xy z&81%Y;%@dAh%0?IY_Tsa7pclM8K6No{wE&9Ta6pPil5ce0{S`%>i#9R^EVv=0PDG1 zztA6B!Lc}r2zD=B#;(q)s7!(!*xfx#4$~n&=W?md^fOID=dHmZD$3Q{NttOR? z^z)2{rYF$b^vEK=NJkRx(bH8^^Yo&s(CE+G4|wwaM`^g=UC4>vo8an$aF~qzS-S$( zyrqpi7=N8khc{)CXr3AXOOH4;lDdUAXf(| zC#TCALQv{=1|cB>RO(6_ZHW*1r{cp|ROzjm&i3|N$Oy{Iw_aYx`RRx9jJd4^d`vA$ zivTUbW;IVrYHC>GjI#9-tYk1WHMD)IEB^L%qIrz~k?ovU%lKzBog_AU-Sn4~7j` zbb1MFto=cR`qQ4#xe?M6S@;(B{N(bu)w<&aC_`|MI_3amzxt?y4)+x=$=(%$74u5S z1gD|1;>c?0RXM{B6Qu-~WkgKts&v#)E+7GlJ%w})jjZMl7I0o(N*lc#VP+O0!XTl) zC-@**=cw*DD0;QEh}LuXv3B12HLg@&Odn6}rO1;v#iKX&F)KCYp-6_A{niBnA;dFrpPe>Xwq3nLK6AR#cq@oB;f@2s%3dtfG;2efVB##@wU)wu90eJ)GeJLe6BCoO ziwkLyTx{ItXe)A|Sr;;S5kZ_ePo4$R+Xl&1CE`h)@#_gvziuM?H`5B8mH5z#`CHzV zGRD83@I?$Pzm*cP7WTCSPclW)RZ;UXAP5;4Oerc#djBMEpyEl$gi}VH`6=x7In9YS zfN#B&_knh;3p^yhM`()=%o*JO1Z~BWl9d-Xv;18ii1hU(Z__>B5}k)8L^*c=>93Ga z+TScxex0eL?Wk^F{2XU%c$Vw{{Wg(L``S(0q(?TrS3Hro@o%ElKc-q@bzW&$fzATDP5PhI2%HE-{|eF*<&93S=@v-wD$`b1}s2uc|l$ahIjhIdy8W$7WgB%&xfY#&FTr%Qf^@aBcV z=kcx|w=zTIF`e!@FAq(_h22bq!A!58){f{rJU9)anaOoMp$9Rc4zfm8RDhSJF;+hY z!CAxP6ds!zrb zl0!S;(4{;hlP3pKN*>Hv6(QB}x>1EGa&uQ^ZgY;cxreX)!LHNWHtgsVI`MtN<>mjG zok6Y(5ASS`KS%ME+d8qV=<11#Fw>1<_$%*RVG8sm10 zi&X_T)lw0#OiR2W8mcTG%4mkCTLO)1uPm8zJ9~i-B_s-HWgX1dsMZchN+e1-o2ug{ zKSg=lQ{Rp91{1TM*;lG5Zg+^LLRc)3>le0}`)w~$$n|jYw3i+uzz@|ozpd^6_P))p zo;Io&7%eX9#AT%CbB)S(@GKM*Qccy~NE_AjHo}7q7xcBF|0L4V+8iY`-WzFO1pGTD zHtv3x%!}cF@+C}xrzXv4foI`{iDTGaUExpxb#^8w+{uakhs<+ytmEvL|M4G%_Tuwo zK2dB(?(Nx`x)8;My@QLijt6Ge@<>z^k?~2Rtn|3sT|J7O$8$r0*Qxu>A2{!JzS25c zynO_Gi|BH?85J~(ePNQAGL}Wo*r1+A@w(Z5?FuIcIX?sH=jI}G9K#mAqL|UR0Lw7- z?c&VA<)y^83sllyps(FgaL!^U2NI4bp$|vhyN4HHz*`0H_Vp544(K4e?%}M&PgN7{ z#^i}*^a^H89oy5B<%WL;YRCsqZXaWKDGYMK!t-T}mJIkkxkD>sus24d>iVm!_?D}o zdUR$?5eh7F$Ak!7bMf;eml{MJt@FE1bBZKo&dq&YtvesJKJ0`CJ z+PKEIN%2*+TEWejCOaQ{9D%^rN<0^)p#3XOxV1I@7C$i1I{uW_CZ&>0y6_4k^Pz%1 
z3HieW!@(4$fale%omab@(;QEEmd)JaCb7HzBXw_w>ho#HgR85XjEe?8_p)Y6o(@pf zWvhRhOYw);rw_6O-p}W!5wclwN=69W=*lrWxPgw#@Xz^&i@U4zuVUwZ@6>GT=F?Tr z0vQjQ3e{9@F3V3ZTaFrH;PE3mWQtaFVf%^^V zlhUb2X+4>W4bd+nKD-WoT~3~kuUm8a+6j7&-+%0aBOwjg@kP9*&Z<+VVsO-1#v9U+~ zE74oS*VWXh_jW9-nZ}Ipy0Ml77;($6?nQN_xq91^TslbvQY>um!f)=j&*|-iD5wxK zIX>uR%b|cp>M>hA=nHT~U+P*adA9)%LB|nM?Ts5mvC4-NYdju;t@?fZ`cFZ13r=5I(t%b_ zs3>=JtRy5(`dj*=)7!9Ge&Z#eXvl?sZj`64herhw_jytdA72GiccN4?+Yx7ZSBi#9 z4Is<&ja#f&s;(o=2s!x|uFWzAlDN6Qa!#ZE%o3w_-AVfy7dFi9@n9-zutdi5==UbZpR~9Gp;M`PENgz#F+ASZAzzAi3yVLjEB(9f|zEPgl zGzblm&0f~7`faaVrgI_@NB`J{!SB{h%2@-=5P{h&;mS`$xYt8But)^&m^hZ-+nZ@l z=F8;zVmbz!YwTdcBiX%^yiX0zj@Y`StezeM7;OFeJ+vN_VQa|=9*h^g4R-ZL_3mE4LJ^x>P%@I!bK4Z&3L03i z@=V-0yKpAh(=&7B1GSg8)%U)T#y@OC&H0YsCZZ1R_UO=ILOgOyvXBOX2mqgtG7cz( z+i*t0!NI}uUi-gBkF3nhB$5AmnG({L9!h~7&$cTW(dJPjsJwP?;V5&Z^(JdySfoNc z*1KK3c{8Y<7f8D{4HFn@NNumA9NQW=QcjwcQF2+q;Kj zk4Oz^-6Xxp0DD2uHx#tHAg&!(g?4SmXwKhwJkuyO@;WwMcLeO|I>Wtq{hPH7jkbj8KL>laSZdM z;cXF(u?W{>eus}y4}u&(Nqo0<$GY@EO)Z91$x{WDV5N0+x<>4fW|Mq-Bkw=BFweDm z+NSO6Sr@Tafkn5DGCCW?Ik90No|8s2nep4Rf?^ySjV2zp#%sZs6}Mko)x{Ibg8|4p z9|b?iYmD!I*WPkzlF}#e=?u^?*#)f|6wpYO3BL4Cf|nMoljQ*uCN$O5=h(0+L z2=U?pJ9xv~U}AR8yCkj+LqEn}bzcWiJ(A1ctv|mAzLA~#y=6!y_MiQwPT1dK3^59F z>*E;Q)piTCdTaW7+t23x_iS6MNc-je^XJDckO*|HV9$Byct955SWkdFCFxz-x%oNh z+im4N-i!o{N>*}IoY#L$M|aVQ=-$N~fFZ|jqNy3!`iOXq z^#&o^G=R(C`**X$PPNYw)n8sR!0(p%)V3OrN+AcIDhVaWpz8K0QvROo!kH^YFE7oB{(zyA8t0!OpQEk(Q(DKL zbJRCF)9&YR@c43>G|sn^4qn(aiI|#&yo1qu5}Y*mx_*M!6y1 z@b4tqzFuUYC>Xx2Lt=2xZui5kfvun(kTx%?EWjU~ zd+a@PcmFj#0gW|Lg)>Brx{kPZFEvz65uqnv2`j{2s9-ipTAatWYiM&58NZ4?r$Yq) z*hldd)Lv3bkNuy@G@^(7Y^6gyhk-nYy7>tU4#?lvhu(K$4OLJ;MntSQq>dy>0fB3I z5q(Wpx1}XFW=ZE`kV#^AvrF)GLrHw#W7RG7ln24}8NcNlC~~%AOx#NzCdn-(VUfFe z#;39gIDHNebIfGa=ipK&rQ-pG9N4isC!?XkfJd2Yoh|q1Dg=-g7Uy z>%b87BqTRIxe8Rxt2lR+3Grb^SI~~BHhvnVfuq?3V(WzGVeX5x$hgqMqoRJn> zd_3-CPD1z}YHeMCvIuumYyCFn-l8H zTXeC(%?=gFTYI6yGV1Yp*s>P#h)uDvHJX7XBF6v6(;FDrLny8I-+_^P87RktsWK)% 
zA0rGFskucTo+gK!ok(1hp-rfRSYVr4S-lWkP}No(oQx6*f^$OlQ^2+35{kX&S+Icc z7O}dHY3v{9%Jl$Tg7UHf1492^x)7wDLnC4*b=eO?Ni0RJh!&r?nAia2r2ZBfMsmA& zZP zQ=F?8I}+}B*OSN&t_-Jt$9?VWEUA*lH##~!rX2Nyn}MXiKj#SBaabxPXi(1i5jhf{ z6TUIo*&z_8`wopy;Op|Q<@AW)xXflEnp(g~;h<~;uuklQgbOZp|8(4Qb(9U1=|AV? zh68Y;UBk)?XHczATcD_-zWEClAAEnq2k|=?Fw|#URB+!b+}*4NYG*4RZZWTPg1#SY zg&s^z)Rh%E=^3xm@@#o`mb5k*k9oxVSGTmlu0FAy(bHsknqJS=uvg`9Qosq>rFdI& z&5X@S0CU=pII*XXPWvjLeW+Dkv&StPpBOoqB}~~@qX4fJ^>5tBT8TLuF0w86eV>HIW9ecIbZ&gjvmxkI5@FMLt|qgxXQOc zWUDf8fKS51v#k&9lmMT`xnAVTC*2xSp&puC?xn3^N|`Ko zoRLvdS496;=hg#43f4Bu)&?hPs_RCa0Nc)}N2@eL$5$czevL#Zuwj+}SEr;@XGf@l z%LL`(hmv|uenNq9kJ(~H%i~KeT1Q!BE3u4sk=MP3v2H{a`l>HLwzBC4op&v(ykXv- zlA^M5)!z9M`!Ij^U3u@YTyKeK!W{0H&%X4~$qAs6`{(!l1H?&$mYX{ zZc62w#>M9`)W-+KSi0l8Ldvq}aU*))cRI4E`W%Q@ucFZ(;KN@|eF=-7x8O8TZN8Rg zCAi}Wb-J?iz9r(`RF7S)+dVQjfxMQrrca^6ATmitcIMawSc{|M<#~A?cYHi$l5gk>3?%WHbnKiwywTn3=+eilhgSwFhr^375KVp@m3QsGKH~r2 z5LE6A0vh%2s37_p+$XlBpjvoov#qbwOVkEt`C74x@gi@aHeD=)1<{lSGsmn+s9F1z zF({d+${GjzOrmdmAfHVD`)vPe?OcDRV_uNfbPRsVe4&()l2~2lL+T_Q-xHgqcq-}v z^J&$*s2=-0a)LTt@u{PRz&xM2fwbx40?fFtj360BB%3KcCl60V+@$6elA?@)M(LF( zfnMBtBzn#sC*bh%*DrU7u{Bx;GPyT3>1xe0O*jD*s??H+?7{J3@(o>-Svd~}zC7@IE=jduM?Dr09O!Jk^KpC8I6LBnNTVf*5-n`z`v9KfKghI!7*4CjO zyL1r<>R=%k)%+JS9->JqXTeoRED|YM*WE!XbAQOfST9Q>W-TGY#jjghBhlKS9z`FI zG4B&A1S~6Rk1~D*10rTBz;-UBFl;`VX|2*oiY%9l;9D?&Va~KhHl8ClEYaMV8{<9M zb+RmgyoH-5>h-Zvo>dFDvno~9w059>Wp?$+%F0TAej~atpTO!_8zl#YM_rVeB;!OG za^q?fMEUTlX61*e9v2c|P+&BfEU6 zJ%~L)58|pKI<9;{BjP<_mYIr%>$fe@imItZ$T)WR-_Qv{L`-4H6?$FFkC&b2b8s>lG<>(R9+p8fL+WIX|auI^ron{B)eFIH5I0J<8t>gsts z8)^hU8c(TptGqUnZR2XozZVhmO(F~^X%&O3i&0*Mo4-uODVBG(_irxZq=<04o~lY4 z2hRI=I_KZqs!?AS{jb71&_^6Toz>huUgt3~p3z6z@1^u*jf>Lr) z(g&-Ip8VAGHaXK3A`gTfeC5pz!)->|0{%`DZsnv?QG$Z3j|zYt<5LSs!=*&BMxB5p z+h?S`J;SpUWt#ZEx@F-|UAG@$|L0 zK43-CpTx^_wh{C!E@~%#<3U0c?L(@!RUPG>9vy`tCr@axUZ#Sjr!4~wXwA0YjBP%& z%fH+@!B01TW|-lDJolgNCyP5gX^XeRbAhVvo&|A@)tVZMzW65hm#UH>I*kGr}GU+Zlf(0BYvpX9{^RT(#G?ces|l$oa2-NoqS%T{r{@wPIGD_>BW}5%Q){ 
zQi(#a-`CiOf_qe5SLkicC$`ZxZVr<(wyAkjH5&GUqtP8O%OJt+`e4Xv#3LvU6zVW) z51Cqn+L(ZouOgUG3OtrH;K!EgYjwDGR?cU?EZ%{fnpS%(ezJSHU;tJfv9QNqRu$dv16!Lr z!07^a=X?hYIRb&MUg6>0Ie6VQ(2yfakR1V5d6%$=<$mk3y}fI8W+@~+UAV>>UeR$V zfIGYfvQABQCaRZ|a?GAX?W@`{exOBnP|FWqtgdCI2)!iTKRZI5*9pQb+KHUzxfARN zVLAzfoQ6#-ub0UZP*Z>Pmo&0TA@4PvlAqqFz+0o_*N0jTkL;)SqN1Em2NEDc6=i3O z2ERugd#L)tW)g71^GHT!$Zj-v#^@*28^6H%c@CGKnD1{{%JMF_@zmZ{XNbm@xIZ9< zP3LnhRz+XN=*}m47Q|y~_5HiE#hcD6Ti$}`UN}#6ht-|89hL~2U;!=a)KXWyl=8QZ zQtr<%SQfh*-InMlkB_+^_dfWo=Xa{V-c%QcA`6?6I8C@ebg7&_l>Hf(gcn(y8^&3B zoT~+g#&^(~RDn4?!3ZOsMmDzOKF3$Aaz8!&T5Y;AGf|!kTnS`%4}Lwh@CbzECXr;_ zs7Q|r1zKtcCaTB_bz!&d6#fLPk8P6kuknBL+}Ie-y~s)6+3v?bR5uhOzNg8&sPOvd zvm{rm%3N68bmNP)M-N9a0kmo3$`4IHFw*1jXWNa22`d%U(vRwxQ}vd=1Q@a0tFjT? z71W=KHlQ(&X_;g)3X>u3g*b6~Rb`9rdW;rqDzzS}Q>L{|kF#(GHX9H@>jg|%sp)kI zs-}u{JGvw>`2yX2{9@|UVXB(NW4owdg>KI#x}$kX^~Eu?#$WJ9IkHj3>ei*keiatl z*Ro2>UNK#5_961vbw7gwI=j7E96JR@%01N$SHrv%T?S%+?4b2bYTJLtg2F<>9)D#? z)f&gT*DQ}{1FT+7@N)QF;J3G}pT6vew{LvmCbmO%S`?r1Z@9clXGGnJg)8#@*x7e> zgCTi+ZE5vND`p@NR4Gna4GZX!h^9F{q_fg9;&D`(qh$Yuxz?N;RlB@x!`0b~AezNG zyhalyugol?Vx?CcWb8)7Uc$VCVT3O{@_iA?>t(#Gqypcd6>6g`kiJT-vTF+lLU1e? 
zV@LC8IVc~62`Twf*4~73dOv<|i%3k8LEIcb<}oSQzv@CCmf;&2+I#7Q9E-KswBI2) z*UE?SLi)V|_?ZM5TM7hoK0YO-J+SBJ8ozcr@mZ?XqIzpOc|&vJkAHNIim(zn9DDz* zbW!#V&Q#dd|2&VR3OPSt1c~S>SEmy1N9Qt&a%3k3e2CX#vYvLU&&a7RG_bhU$R>CB z+Bls${)C<$L~Z`)IFxtCzGnufUSGrM1%fjn8LkiSiA;Qwhh5P}A2|X#7VyGfd3Wr%~9U=}8`{y|g`8h<1-p0B~#dB2Cl<>%ij46-xI2vZ+CW)^034UGg@3)IZ2 z^vY}Ty|AT~bbSjuR2(ZD!Divaqi?<%*@5>?Zn7aJFnP7W!t85V`kR0x9rkZ5`Jk7vEou~9>N&GBneng z%}O|B%u3kU+h?+|*H_)f09^~*#9{W5lsI?B;n?R;+*$jvRi;B+oN&IWvEzi){g{#2 zAT*3*5MfczIZw6<(j8s)(p4kn^n2#M62(qDDnJIs{#pG~wX|(yVu61EgJF4zy}6x# z_HWzP>UoJwkQA4`IOy_X`C$o@+?#MKJJ2Ev)2*nw&iIlWuzTm zX*B-md1<@&lAc)hC09=>gPR~V%(7`d@u&L;T($?SNJShnqQ!MELITzU`=l#7>&kJs6 z_;!Lqz29*37rUf<>V%RD1+H8e?7Tr@#K~%9c*yCe=X3OA)t^ORq>XD+kQqTd;yXxPQB9p_eyMgy`@o5aYZMr00H#M@IH_&ObU9I5~=aaQ1k!=Y2eYHH89TQe# zg22hHb{O!%S0PhFiKS%(psLjRrAam7PlZUJP@-N{()ETU&Q9DSrz-DOS(r$RPgb;t zC}poNNXzBG*!vDW8>Ntf_{wmdj=M7{IIJ2o`MYl;VntubUz8j+4$c4Rnf$&iGbI)l zFL&>Igl%&N`Vj@%3hXTht z*%)U_rc_sCxumJZ{rl<>Q3xcMwIs0gB0GlR^qV5s6@a!UkmZG{}CjsB42PVo0){SWpo=%89&(*g*_Yim3tu^G`4qYk+ zdSuVqPkp~IW5=wAHJq`9O=hBz0dM_Cc=ue0k?w#5_Xp+3Rc^1eGydh0Qz(Uu~j zb;oyY-Ij6{_Y-(2QslI9|G|DDN>#5t!o2rK#K@FTtfq_-{k4%-`<0oG!jD`9&?^<9 zNZAM%>!me z>0|FRy}@UVcw&SVEsjOQH1teOWUMEgWj7nw4_(AKK-kL2mOpo+g=!~HNdMtb^_=%F z|KP|U<*A{a6m1-iDcO|K@~p7Q1*c{1kx^!F$qeDfO_9u%$Q?Ho0q{B}F zafx6S5eJS-uU1U1aTWi;3&F53;4DpzTJrt{X;CCFK-lHtLHvxUNzC3ZX3<3Bt92$c zwT0mCK6iLWdAi`p=$LShEXm4g34_B=^y{;s=z?#%zDZ7mH3<_sOCOf^?)9A*x8%4q zQK(&(SidWh<3}*DiJ`;Pr+>C?^z{-h3;j7MOR)ZT#-dxCNktQ#_=r5EUF63RV@!Aa zCeI#0%M2%zoOSnl=DkDpv3%TOb$pGpJ*^^@2!Mv-z=eh?{!kq|5hR!9#Ym!$K0M=qk=NFtNWBEbO!`Wd0Q7>zJf*-|~ z;k9?$qe){Kd3naN>E6X1{Rj{t-I%DO5hDOFnzGt^Lw*Z8+hxA4d}v`srHPBo?>ho$ zrY?BXqLmdigDsVg1}&*RYSU5iYfcQ5Q((ZP3PR?n`f=!labyvhlwUqI{@;gB#gunu zrCFgwHSwBWT)O_9jO;eSkx0$TV=nA6l8>W7I)No?1^?WRUUYte9o=$-uU-68Q5S*{ z)9ZS5^sU-dcT3>1{ma#uz7;lsjXZhMv_PKMTHZsXW`DQV3t1ug^soMC86B||%L@ZM zO}OC5xI^_tKFRec+f=#>XEmD3|u~B8UHSkL7V%6mUB!i-YU> 
zd&~TV4^`TI{tGf9q=eIgvUp133q1UrbH1#R2KJZ*@5^MFjrM4nWx@8O#rRbOE*#*i1xmpM!Jh$VoGHuva?cp`n+%#`KA*FOiy`Cr4 zILtrcsxvO|7A5AdFYJ4fVYXjjQKi*OP*e|zcm4JNuJPEOMAsSttla3ca0l-9Y>1m- z2~!Q4kA0?5$FRE9ul5K%$ES=~QIv^mmt&URsyWc)$ z2d*026{a129{Xa6Eb-F2yY6ek&f9&#uWrNqUik3D z7_cylyFqRg(c=oDjZBh`u$#}L!-r8GL2mbKkuhHx(j;n?$J@bngr?k{{AjnQtpT1W zYKj0%z;(h8!10cq{j^%7w0uxNQ^=CjPR^n`Lel4XKWZw)i8rGhZr2xv zeQ2y_q(KTo$57!$R(pxkBS`}^jmuJU^?(;u=(c|@YMNwaLXmMjxjADg=LRG8ZD`&bPsw3n`fIZ|D`QY8 zk)ucE|8@q3G?ks zyVYWOGXNHgX}uxj$UK#jS?IDNL;!4NC8y$EeBZn$ zab755@6eYtaY;AzLpfZL`Kvr*Bx5`3YP*038x?A>4Qkp0vc{K7{5&JW1!Gg-w-MZ$ zF-!R|YI&N&MR|+Mc|!G2R#e^<10bV#)Nrz_%sgU8Fg1*WnI3gC(e$x;g&nmPRW!9Q zQSP`+4J6$_>4A-R1&@Y-4;{JotB@nLr_r6qYPgsga_B9l?vn`UNLqp&Pm)rdlCK)} zc-5urr3GQ<>8XZo8eq%|)Zk0T;OJdkeuq`N-IG)>gdWr6qAlaa<=>TC zKb$-Z*@zI(y1aSRA@#niHZSRiAr+ZaFUCZJQ^LWIsPhx%AhgHY84r%7^<>OYWVC_p zox;j4iU@uI8$H{T5~C2Mo*BGva5el~oHedF(u|D)FdlsK=K=i}4PIHx8h zT3MB<{vmXzPf2M1}!6YEEKiOBYuIa+@`aVn1ItxDaDf|&V`yCzk z$59F*9HtU=YN}p*z5vBfNjnB%BS&?ElnBAfL@0S`4gS2{;P5b_L^OTEz;LCq*6%v@fp3J#fq}oAf5Pa3I=%ugA?5;S@!l}~cwef@6;~tWY+9J=A64W2 zLHtU$;jHO}JQ!?A$5b2gDcTVgR>WDbdii zm|Qd(=V&#%32ToFx(V8#FYNZG#U;WD(}~5F4z$Z|^~4k<<@GpZd@E|T>RbqUhzHwS z{zThLlWrpR<ey=d4;-n6cnSrFyPi3I z4waf3tEwPbXBuRMBfB8zxM8=sm_74j*73cJqKV)S6B2?JBbX=H6hN}FDu zl6EU!l-M8uuV!F53^<4Gt*uwMu^S1^MGMD9XmHoj-+g+>l}2;`4GSI{Za5u9wqzTw%hMt-%zF_f|J~-vsq$ zH+$ZLF}t9PZA0I=?0os;S>Il4pLbvr?m;AP#%^~^buvZii-kbX+i9!I_FeMSM*=M( z4BU@q#z9#4vi@w7h}I^}v91u5x()3koz;@cHWalY(4nabp6e04By(j1D)toK`eJ{> z=koQdyj6c6EhO>bqf1U{KAD+(svlE zUsWiNPCA40_NM;M&AJ^2d{6mt_9=;QX{49XsM5Qk=-HifVMyhBu8_in^KCAJ*V5B_ z^DhNCFU?PcHGjqG(_MT%;URg((rU))+%aAdht2xv{_AB$mHt&3;6us=r!dn=iTxe3 zFfBR$g4x^GyR5t%X;nvNN)2ccirK1@T?Xn+`yVY4zamKY>KvJ0I zW4YJ0_=?Wry>b2%ak4ltB@1)>Q}A3TCX8IDoV3P@eE8dgLJR~Lsg>OMzBu8l|BTf4HBUq2@eshM-~I| zA1a;3{?9v`So!YgpxYA*Z+C{LvrX@wy(*&7!h7D6 z1wTCm5lE4KJbA=N|kxWKaFr(!?JV5mY?J zlhrB9HWrkSH7Y;8?^nF_^7L}freyF_T~btwJ0E?1<}slsXv8SkB`J%t%CV8tcKZB+ zs%k2)LlP_EamI0vOCp3h`eP>w@6IiYU7Ih_VN>Sp(cG1YFVO@gB3$(89$={Jt4|jw 
zWmJ%$L;475+DLACA{&!tPCL->xljKOPT+GHE<0bec34iCoIcIzu|oboVM2LX)$Xxi zZ&YQ=*)w==#>Bei=%iND3_F=`DyfvBSZEw}<{J;H4nIc*43c)^qt1P3 z{5&(0hC{9l1@3x!p*@(nBB9B($1yku90oTee{sY%V;`_d5;E!!R6|hKm%aN-bIK)2 z@%>Ou+K8}>)J1#Z%{$M!?iRl;goh;~UK^&|^qi8VvOwfZx2WZ|A|5XD+Ges>Bx~z* ziL%&<_BjWAV^dD9Whvk<9f7dtc-~*%1@ViL=f0zDpkw6R{Y`8;4zhrp{3_CCRp&u_ z31?Ut(KJbNt*{s2c=cf-dNIxvb!xkk8fl`WKY6}=`DWu;rjw2Rz9I=%qif$`SgMR3 zN-BH|>BFdxi38!+&39i6gI@g1O^YS`Z1|~c2z6HHad$BdjRmQ zkrGk*xU0{Wq064os9I!x$}GP(OyXu3JwJVi%oLz2 zes@Ql6rOLlW1U(|5h+d`QE~?Dk&RH_%;P8eIh6?KKx3Fbb(A(~)MMGC{zqUW%)I_x zJ%ydO`2WM!TZTpTNB^S0z|dVo3=fJn#CNSA`r4T5y{&>fOWhXM+c(%q;u(jC&> zedqT-=RD87_c<@;&BWfb*IJ)iUstZ8TBh9Hl8HfWy!R60-iEVYZ>wL9M{{9A(4@)E z?DH%LH~A02UC@rrehLu6@-;$tY);Ph&byNa*Oz1knmGsc9gtq z?Molt_pvdh=QmWQPx_Yhb~&WM?Oi6gouIPYL14QOo?I1!y!FTauLXC3kgtOERTY!c-~Y-q>|2iO5C0}i zW+swlN|g3IQ!u`@(kH;v7#RYE`HOfN~^dA@RLFmo#;rxHi|#kQvj+!I@t{@|XiV6-3> z{Q5;i3I1(V3g_&#^8UC+D1wLCPgvA_b+X8{SMe7EIcM^_YYOV^-2%$@3)7aYHO%N5 zL&@)aA*5d|@}jIt)O4B16)L;iZx`9!z2! z_E=5sKZpG|Sxyp%+ot(CVi}vgBY%(Ww?zlm_Bk&O9z$jIXIPzi^9F)hjo&lVsqN_( zchuV7n=t-Y(HD#k7uHRyyhkVt0iZ$O1?6;^!*ML4|ynq1Z#XKW$93?=qb0-Pyv;Zh3pzkXG93!P>8jn8qfZF50`s-s_5@{cAR98hJ^{hrPyAvo%CYU2h(-jjk z#pBO_nd)?ZMHpb3FE>T_!6NT+w=4;8aMKymCp4^-l7y;C& zi+)rtnpWt$1911n6$@^Gh6O8hpTL?cQuu@_)Gsox+9_MX$%8|MreBAw(9Z#U74U4U z|NQNLgt5Rt$^THm7ZoDl-8FXy3DEWM!-kIwnV7`as3--Wj(%42enADl-^7$W!AkUs zsy;$r>%lh%#*rceC_+NV5L(kyD`Ju_+Jj8ma=qjmbItoxtBU7GPsC4SYPlhlt^V%B z-qP3S9&vZw@S&zgp{CFve-ZP54Nk=tp4kfn{ihQ$p`yaZQfv}FZO8cT9nbV!1frK; zshFequXt}GJTG-2CyyBWyo3fR;8O{KE$L zh(cb5hfH=*C;#Ge9KmdTaJjdMpT6%Bja z_GlQwaKxxIkDNcUO`k4$lV(ZZb4gd{AzH6eIMGRW(z7$Fg4kvFf^E~j zdk>sMN!^qljdbxEIa5&5E%W)(yyC-A`a&X?)*<^TL*^!M(Cn*+DnqKJB zmc+9_eXh*`WYzg@85@s`YAJ3P;7cg!aWlfVTj=hSc+No2S|Cnhx7DV<^0+snQM__g zW>UbzO2Ch%TKN6MvF8tps3$KkD9SV~Do)v5eedM*W=wIV;K{hB_Oqefq}?k9LZnV) z>gpZLqJF$-nvVkk?bMBG+;1vH7JK|+>JCI2@*PFUby1= zXrV`pU8jfDzx{2OcPj_{_zOy8;0uAwINZqWM*F+BSlp9+<~P@u<5lL@4Cm@PEkewFCsQ^xHiE-y8zmDoE$c%Jjd>MP5A2PuZV$5L18To3)bhv 
z|ER~mrd$0LEfq%GW&<7Nb09GX1N%A)OHRr8$N5U(w+qFgr5R*diT(&b~CAYlolsoh=P| zJsc#3C%2LXSYQORnI@Hj*2Add_I9%_+pg}dY%YoUccz3$MRU~EdHCeCG1kw%oxWEO z4dork2%090L>SJlc)?$slg*?W&YvE9j?8TMSkZGm-_h|vj+_BFkTbwBK43h%1a9g1 zl5S+@lsTWSDn>EwQxp`lQYrXjiG7v#1BcAQm zRGmF_#mVc0qt~gV{TaCooD3gD>PH#GLOb>gnY{tLpq?KME8$epR zoFl!VhSlum!~;2Uu}01W`~JAU0LqC6>;94>hm+X{iqPykO3mW{kA7^y!~az8ul&vO$pjMiTsnW8PKgw{zwL-q&^_+fvAG^Wy6Gu*`$}d5`tMO*BwJ9)-u_1e*#c1eyqrzyA8FNO9eGe@8iWb#;|1dTJZ!Jj4gQVu)6f84sm82$B#VuyvNi>lGxnbYJm~ z{cMTxJwuZ`vhW^aIxk_Z-u$~0J_06HpeCg+4Aki@Jeej^4Wi8- zQc+}3A$&p*i-a>s?X~e>E+g(P_e#{GHN zyK|(EEi6gSSaxQNC3{=+d?)a`cPh9i&bnq$I91ZvYCK&12Lo76B^jaK6+1HIq5sO{ zv@H*b%Yccq$SJI&|776W8>x1@^?9a)+}(}4m7wh^Q}vP+zJn_0zM(R%ILIkyL}JHe zP)H`4K*!sGH4*#c<#tkBaFeDM8L$a0TAsIng$>&I)1mF52l1uddh;7wHv!?d{kF`> zMjZjI=UJ)z80}nbrms7?i#atLvqF7e3h`(7=IUcWG--%bE(oUzB`;?-m_ z>9$U9e9tJDHn>WB3bwl7`B9<>FVmziTLZgWqg6I_v}6;2fTP`TP7{LB@ig^)p~cJu z&=-)l3pxj8LzN#1UG5}D&cS?J8)>1%C@!WXadHbq!Va|S+E^_Orrw>knH$=zPEE& z>7+OHj2Hj4(PUN5_x+=_lkap+PSyz z(q-LV#*lxE;Uw}JLhXb=S6pfP7dEE(CkC{kiRx2qkrbhNxc(luDd^WeDhtm+^Vt)XpW!50cy8?$TJZ)F4> zgT^*u?~VNwBVlou`2w%wG_`nqBbzlp=ixOaCC<)!jpfR@t!kIXEk9U7^+VMm&@OMI?pA59FV4Z*WoTq z^8*cB_1`N|+(wz^cbBZ70<_w$Oy3>Ls-dAOlYQaPX0O`xi3$&rFwfK&Gp+z(L$t?5DtMeV909?j~SR#!`qc2dVMd}d5t*E~#eKmS_sj9BT zTj7p?c34s&>c&)L`%HD)uw=bWZV^W^fhbm&N_qcpszZ@>jE`^Y`n4?SmHj@&9Nliz zwuh5~qWsa6Q?C<( zm`qG!B@G_UpfA|u>35G@9MYR%kQB3V?qC9Qh>Db|N^cjVUXgA9+4~8T*_kk2C4`l~ z!!vm5$~Fr0D9}Nk*NuJHLT%ZrW2?nro~2;QvXvJ;j07sFi9(_n%o-V|my4RKNV(qd zxiQF-nosRthXF(_<^H#oXgji1o&BrzkLblH z6%Q=cVa<*~Zx!$`J%fEPqQIXgw5eh>R!!xD@r=R&Ye`mH;Y zn9+~8!EU#gHksGpAO}3UN7VfXw<${%`qlG%wty&~_>J|4%ODy5G$JSDW$rp_V@7&U z?$bLk@ix1=3IH^ugf40$q3}UP|1E&3ErGLt*P@0NQ5`@?cu<-mJar;t%v)@~9ngyW z2~1obab(aykW5cE5rtAWyBY@-ZeNldGeb)u9bw5In?hhV0x`(;$shkZLSVbRZ3>fS zXE+w|3p`Y^*M<>EDPnRhwc#8<)DUA%PjAyygZ8NM?jWf)Tm70_LasRmpj&8e&u4xe zW4q0`>8;n{B(1eR8j0|XtJLIR=*htwkG}P8y4`y4JOZ!kOGOd_62~VeLr1nJVR^;g zkEBQvEbWSTs-#brfkO84Vz5g^O03)_1-~mDZIh|KPw1W!KZ{WquFXHkrXzt 
zrcz3wJ_@QlM4a&^I`M1JjiTYjgwrL&ks(rhI~qNz{Ss9HO}I4A-!#A60Afmjl|R-} zYgYIus=Wx)oNqhL0|J_`fa3%Ok{{!S&?k)sCcix5dKk?-@p<+;F-m^kx0n9~;QGjr z{d5EXU^BWp@87Z>WKT}F0VIIGG|cipZf6DIcsA1F$+2=B>ic>@F!&`^$g^~>$<9-$ zZx^qc_1#NRF12P~g0?25||P@Ydy$&OBwVSfasib`KD#i_o8^L4qsZ=u2I(~B155i$*X-<`QR zDQb|`LGtm&Q&SX97FL4CWiOQ5E@a2q;^%6U_b>xNTO z^*-ywZ3Ov&ls*MN{4>)ZeM>`_PyzbDsZjVa6v0oY+I0|tU~3NZDWjpX^h+5`PSr^q zk9BYwc<#M2C7;+3UNA5`2bI(LQDwy@svkt>C9+L?dX~_ygWi)JQG9JQ%NKC}0|NBG zEw^Ka<*c4DJb$Pt(}Q-j{u@`HtIj6P=s%LKbW{4@p3Nao9}wU#y*50VyC8l&RGgP| zlp9-Ozf5d|RN`x;C6?n2*$-{LC{^rq7Ow?Q>evxBzKCwQW!L<8RHiIH8h2~S<8ml! zYPciQ3`ZUaoKXB=0BUg#yvM9x(WPfr=-2k{Ylv%kf#KZ+|8td_aF#(i{0dp(bdMzt zL_X+Kmr56g05qW_Dq|Aw!_K1$nS*G6uKFe8v?_jl*wBwP<9sb%)#`317(YJU4=Tl( zYPmC?a%Z~S1CWk%D0Y~ONNGTuOW-~7mG&U{BV2J_Y^fw8;2pu zhd}|A%q~X`t6J^Zx2jOmi4$kU6*#v8*(5d%uw)(gG0Q(CIy>KB%qMyb#=WU>@0xfR z7|Sj@mo)5u+1ivAtcf6+i1M$y;O%+3B~r`-XC6KEB_kH=+4-a%3>S2Vjyvvm-?htN zMg-)!1TLZzj+Hwkt47IJ~gp%)_of`+by^mBq3fR=J z3p7#w{-o5e71nQbzcSb+P<3*Y=?qT8>MhYNvUf%a;5AvdqTdY)FBzUmBC0U$$Jerw zQWbRNzz9awa^RXbq!P{Vc@xJIVs5MOQbDk_&%c1UB~j@A<;D5p0ZDRR!YAI#5w% zGgavHCMY2#$&xhj{5k%Ad|)3*VG4aBUv-&nX|P~dt^KR92|%*qQhJgz;&d-Uo&9YQ zr>AUxPI(X&YbcVJ@Ynd?=CyA4J?G9y+~!1>yh0H&&(736-_?$d8|sxrt?t%=?h`*+ zWqKPIiu*6nRe6@hfSCJBls5;Hb-nrg76!9FruP~qU?@$qY(1F>lIRVIj*XRl{~qV; zyt3%)o#2tUf4trAnQ_&(fO_a7yXVu>E7^=L#-BYm@Eqy8_byG9?k4G|_}&gVIB>i`CL>)mj>lmNyzH z@P)Q*!}n{}o$B5E8hGD=exrd zN}y|J)fvc#EXD*XuN?@M+-pycusKFiwAV-`SJ{o+(q!j68Cl1cefEk#TseE^_iWaR zGKJ$Z5GiRc3P!yQG2yZ01#`jhR@Subq8XEor@}Sv4NZ`COpL?KWpR@{_d^l0=>WML zZIpI;l20>!RdBneD?*q2Jh+V&P8b7pbJ;IZZ9vM5=9}XCY+7?OviHkQ)fXt}&TpqQ zv<4}Ws-mp=qoRI0aL0t=HZq&pbLa1VCJkLbCUr@jl>0HUiFi=V^0aoZ7ETqCIOto_ z7sUj+&JFADahtMt_kQ#y6Ksuq!o^g(y#frG-~z_qd1Ga~jo0&XC1NCugp(m|THoL~ zi8Qjo13MVz1`1N#T*-?E=w-Mfk{}U8F^ehLA*J5zH~%G9^q^&pIE0RA^dGg=Co)@K z^iNRSt(-fwvB~%RaKRtl*CCc3`z%Z{BoEFUb33Vcc6usqBmluY57H}5o;znYj2zUi z3MGZ8a8Mx(zbvODvLgC+b8rr6Z3Z@n|cnyf@qfQTIvFdPT~?zMD|c)a>=0wdw- 
zz(VdrN>tI2r|%O?XD!KeGV4@v@_$-qzVmZp5Q`8C8#2wmtliDIy1F);K&U3P}uwc4U6Ob8g#+3_nfv-WV~<_7lw*#vl44;BL0 zh~0>-%_o0sTu%21Qp{Hdz9GyPW{yD6^AEkeuosb1N50$NyW2xtbb!yTT5EW)`$mb7 z?VvldxKp40lL6V{6S_&jH1Cj~Mt61`(Am6&b^k-cf7z(}FyS>XhCnpj1<`(2oUOQvGTP~$p&?Tlq5;dFp}naMYC ztXs1Zv)xX1Lsjb!By+{u1Ft1aT8%(#@)LIhkiD8j!Kd6RDH;Q7v?Xyj{yax0NP$P0 zSRr5Yf5 z4VlspczFYUE`2nDa89EHW^K=o~*i^A(uQt`ltS;gE{pJTvcdonOt3B7=c>a~Q$7^W-MGZ$z_7C@=~ zIB{2jV*Ygnhk|e%rFF|VyU>H`+8`+;zVC?+duZ$7yn$2U;jr7#BqAR z`(_?-*w`xCy1dx4J&HLNFCpVogQC|rXZu`*w%=v*U`D0j@s4ZL9M#-0vAezJu%&)^ zn*JSAga8YM(2Bzb;fk&ykMi5CU^^>nr~78-mT$CFoTr z&p3EUx7-W!R&FYp80rF5Ypq0D)DzH0FVsC#2lJl%nSA>}5 z2EN$c&@yUl>Gxc3#8u7yphlD&)1ky{J1@@o(2VBmN1|gvT1_clIft;gOQ&C%@gf7bcd<7db<*v;zdLx9* zyg>6M{7ci>KN|f5n?tLXH~5kPJF<%R9}Gta<+$VlB`dxalniG6SW8!O$j959Bn%O< zrI%^~0BBv9PH4^WU0zWgtyCy#;cOcadHdr!wk4l)6M+~Au>ehmn6nMvPOFPh_z@`n zuW-K6C3IH>#clA4fm-xJEWe;Y5jb0qoK>P&GOo|lc9BPF_`T-W4=hg!w`v?A#l9yc zNVGBvKANw7IUirHpXkIu$qo2j1$hLB)e<6+r>#cQSiC!x4ezT#hu+f56Po)90PR~v z&7hBK3;>B(K`+R&SSOsPkswFk&mAz1bRI+}d)t$}p|}kqh9Cw7$v*{mH`{oFI|YB{q-~Wr-Ttzg4E6cXV5(AO_~Rzf zep#k0C3NfmZc6N3rTR`_k@zdAfM~selVO!JHR|_oS@{p#jxPoB+ncDkKU|M5>&BO7 zDp}VduT{0Knb>Tu4}yko4GWhP+VmjVUpH`bT*NH?ZEy>3`y0)Ufk5Y)RgxuaY=Ig` zGqkBJr%+N5Twg5)wp#hQ)z%Y5&bu-EJaQ6_Mkeza0yz(P+n_CeM z<8>8^mEr!lV+h$O@Iw34d`Ibu9VF7|rT;}xm>^9fq;3VcWy#8e?1hT%e8Bo^ksj3({ z#D8;GsdcUWdn3$9RoOA1s| zFB2(D1yYu)t>6Q8F|`8G09zZk*7+}o?u9UgJ-Jr3Ja|;TaBQ;X?>_D!J#?36ckFn& z^Mp%Xw*+5JZ~*v1{EVEK01wkHv7ZYfGAmydM;`94JK>mlYwvtwChBaJn6<5i$-<1g zjys+M6Ab0qGzKroqcvk1_pkX{n?e~)Zq``4;Uqu+I}ue{LU;G?{f4>W9$2;e`ye-v zZ8Ny6j)9UBskFwtTG)2lX}%YpYLni@#?y_W+PZ?i|S?gy5)+sWlt?vH%VM5Is!((9#>F&uYONww&H*|5};&)tx ze+ZDErluYnn~tcjf3sM|Ax#_q8-fB*Th<5-o+u;6o03HJ>%2btv0;J^K#rZxJ7Z?+ zE-%+#`-h#&NBwC{u&K2^fEYs22(aV76!8jAIhhBpV&}SF=9iSvEE$cSJ}gnIPLJLU zK>R@8G~i4kX`48@9YQ83RO1cK!Hf$O5uz%@+?2trWDG1NXaE%gT3UaA=3P2W*kOfx zyHc8o3wrWTSoFsbG7n8mZGvDWH8_y#IR0sqcUojhPwAOs&DB<>VAzNSwd0C$B$V;H z5jpXHl9CG)#3MsTjHCwzfuh`bJ{#ypC5+tgEEqU0pgXX_l5krHURWEyVdT}0TL~mY 
zHDOUyRToqq%GEa+J%}E@J}PQ!qG-ZSrB*@~t;PRSxKIJahBX!p;9KCtCfb81GZ|Vq z$?5HFwiDuCNIJzogT@uz@VHW?!7cPOBs48Iz9%SUrQ7i3g81j6mkG7NzDV^;5mJr? zNZ=Bqp`s^SdAGgqPS6uj00P>6L?|8Pu9&x51pfXlZM=YfAca(33rar{a4U$NWxmiE zBxPi|76b|zI{W*EO~joeN2m=*9TL)?xwRg0bfLFES`&Szeex2}(Yk)sr zN8if^u*s*{V67LoUk$7F|BdU=%V%n0S^e_3P|b94mh;H>Fo+kr?CYQaLKk$+Ddw#W zZ#tgt%RlTr*VWzZ8nmCa(G&fj$?WuTxs_QzP!0Jg|Nk6K|MBYwB8RioM{*H+UST+C zO3ED(Ep4Nxj-C|n+vu!Qu={hs!i0&Zed34+(n4_}yk7SDmIy0V_kRo$65x_%o7 zGD)_&6CA0HFtCYYFjk`(lf1U^{|)38a})Xz;9nAi2EbbyV9E$vgxb;4z%HLDYnfC-98& zil2Y6EGa1vl9w>5vE37f_tTF|bQx0BB#tL%f6GGw(StI7c4rbHV`o=V8#VQ{F&%;{ zDDabzq$a6^W2I~tUJjFZCJ>BXZD@ob93cCj2-s2-p-okQl-TnpUpacSy9(3E(&PVAmSaVn@dxCoe=+&0nO@oH9%DG#!i$*k5ZHP7s8_A-$ zfa%uGrJkaLp00yxJsvuq@AGc9XBZD4bn9c1(JgPQHHV6mIL^jP2J%FrZGd#5Yc8eM*2v%a*G~HlKMwSCePKW1{Mp7O&z=#6OTtg#gT1Ta)*jhoa+h6G>~$Y89e;JTFTV*^V{s(wu7 z%#wN$B%J;gjI0k4V6$t-?}n@Xk3l8o3|VL>uW|&-gBlYBwvMq#Ig&s604frVS9=y| z>&4qKN=RFf8!{YdWk0ymlHqY+qXUw;@3HP98rCqCaHGmCcJ z#S{CWn`FCg&xc1ACWL8z*ME$Dm6t=esSmBLrC;(ukv}}X2*6suO$Rj@iPS8XQGd;w zZWy$h^cWF}FrK0G>4Gy#kpuUTNi6J7jV~-k-Pm>w>VrN?0nn0>93fi#kO_MbWs7SS zdTfd2CS=8Fgz=d@lKzPR83!Kwj;s{2vHx%6Dg{P53A^mHEvnvxbBgn&x_q+rt&VCd zk~$ehLN9OU&YUO44X&-vKNIqg3}8eb&^$e6>5gUOeK}-qT}lH5!lLym*rG@ovMCWU zlx{F#|7^|?u>RL01_T~W4OVfQMEBkJGv)Jx>mR}X?V%Z|r_LGPywgwPON-~vcz74c zsg20UZ^u>U#Vt{IG&BlB97!XNX_|x<{$W$ZU904k0HDRM}r`n+}*IrUGW*dt|tYWb(Em<*Ko8C=>Pq}{QncKkTF>s zB!5n}AST?~?n1%c#F3t1#PJ&1;6~tC+2nL=i@CSU%J%wTMQr<#1i(g`(`@6v0ZT7PzWIEO8L zL7m!EkfkSq6KTZ@af%+mK%L=_KL-Zm@?WYWN-q{b4Ya0tms!tkR|)Mh=hbN<%cG%) zH*{M@Q2sF<0?A|39}PI4kn2mtz7jT;(f1V6j=%Z>nw1otq2FCx>dEbk`ne_lXLQue z(-^;HW9nod6F@GBg9kVSuMY}G*V7DXD!-!qCD+qw;7<$g8yp-I;g6|~CiI^3Af-x{ z!Rp>*vkeG+9XV?E3&}oq@csKSu+?8oT0iH>l>EYk0#R(gq%ZMsjC;AT@#3yx26R6M z)9lL&LtJ^9jVG>o#$m}+VW(az83W~ zy|loL*+jz%UsEdG>8gu>D?R(=MoorR~5ZW0-;cm>Z(o${lJHnd?c^~QvDq;(7T6iKm}(~teY{tjtzmS9I~g-I5E zj2XN2Bp-QX1fm6uo{P8~ledOLM z!s8;&!F0wc@FJNM#P$~#_&JPwwIQgTD8Cfk7QX{y=sm--{43sdcQk{GXgf}aKW?5k 
z1V6fsAGb3g=Q*4ZI3gBF#AJ?K8)c_;CL{hxp#b00%yjyq;-z}oV5JwqR}_&6zW^MT z2|(t^Pr#%l&HAT9Zj#+JjVhOij?+YyeHP;Yb0PC`c4`KH_EfE*QXk>ioKW7I<4O~a z!)U*3iPKmM94}m*f@0mILILqizRo~mybIUMa=HsvCKz6^gx>YmW{A^GZ^Xv)amjhp zaYr7~?)E=XgFw*EIr4UI8o+)Ob-Ww8r66&2X$Zo&9|&5_Gm*T%rAQctXr-xYbt`#| z_NDxUUzJJF9`5oX-`d4LWFK|gUEMVw8~a%a;prlsPg@&vXQxf|HL;tjmib|- z0{5Np_Pg}Thn4E;!0Xm>*#G+APQG1c%H3~L@PWsc*|bn*eetcrwy~i3r|n8EI?#^4 ziN@+6ydT*ic>h1^A_mJ8OO9>%mI@{XSl4@h{En#zuSEtdUQ(&FRkrA7x2R_}*f?jEanl9wuH=ipyg~==8P(h^rR@C^bvCpA3Q4P8!=fBa z(((oi;npUz;bTfQ)ADZrtkOY;9IuFYnf{G0;FHEH(2t3_5f{%C*a(h`7y%29i1#4C ztQpFk)`DICiW8T2+w({XzGFLmdo{jpn!i4v@&|jzMDOUXO_o#=ImuFF>D{;W?urK6 z+t3u{Fe@cRF_8TtCI^QP?qCgmfnt%+@YR`hj{KMrtEJ(a|AY zJ@(r=3UV*)yZ;L9<+Mt|30Jaxz{>gC^>cJLY^nH?|D$F8df(pdA!8;oK3X&TP`NN? zmh2K2dpz0+=e2RiJ2`DjPw@~x3qgqf8x+~SrO9#s8lNSEcR{V9!2jlUc^l28*?D!g zb6})K2UO|@dXFi@vqN2C5BZFUK|aKVJM4n5urJ#v+x7vN=_||UAqYAP^^H!tHeTxO z5vz+$#ajF@LyGusGw+43HS#>_9o7Faxz2KNl?cQtd;Ak+-PU)f=^UAkOtqQ!KHdoa z@>@)dB0)jnkqk+D&{CVvj9UjfI`aHarU17KE@-eSvT7h>aeG{4#|X0Z_5F;yM4O+> zROdKxcZ(cRF0Zd^uQwMUh~6gqNa+9y;GT(?VxqI_`QINFG#fd6$OINWX>37U8GI%E z@egfNO9STq9o7-$_fFC%`gk_@!YcnkYRqq&ALoBkTBfP$Y{nNk!?~cSTfEcy^FLC2 z+`J$I3l@!DXvSoP_OAHp0V_>^D>Ojc&MU~pD12+qqt|m8l%FfPKWiAn#$MJ1Xy(;k zX=iy>-Bb}JyT$;E;4*8nT%toLdw;*C_#JPY42H~=;19+n$;g$^{)U(ZF7#*PIJ8v_ z7fGZPPHTyNL0|N(43E`rzxh#8XH5q-3#h`rJ(Y<y0prl%C1mf8%%lcAle* zpW3cLveVPu`QLG_FF2sythEKL3_TSF+5yz7E?-ZqS|x`_8`hNSr%EpLGKhS6m{w)I zN+2-H3JCxDi)3@#8rrZ6c4xUCYTk99HTBlPAIHYK0DU*PT01*j=waMh5n>cG{#7Ut z!{1V6zA%s>)n1kAgdCNEo=ATT6W9IwPQMx{o9=~qoX&-(DG9dWIZ&%-xI+iN?-GzU zwAov2x;nF%u+D_LeI$co@2b4DrFr4P?uWM5K z3mn~WH^A+0HOsH(+u;@v$~CUF7c$$h61ic8X+i|CyU*M5+z2DT5(^-UlF>QKtXHE` zOt!J-B#ZIi{+cKrZ7cp?$BEs}WcoAAk=tueqQ>bZj32*zN@3V|5-88DuG;TT2*j^+ zVs19ecMuTI3w_M4Z@#MAxsE{OR@}#L|Hh(fDmiA&LmGEVcY~Jb{-}gD=6pv%?llnD zqi$9dTBX0~KWrD9+*dnj8orTh{7P66*Xd;8veE5lZaw z1G?!7rzd|FRX1!XV*y@CkiO91QCaJIVRtd~m>h7eYR1wEBCrr+XcOHArit4>969TD zje#m;01<|d0|49(Zy(%EjsgilGb@EG3RnoGkkUGB@mtG*wYBzVG2|k-E=iV!7ewDb 
z!j)l~#~(+9JgSAS3C!r)B3qdV@!Sb(ft-`4k3XzVlE`UvG&Ixl;ofY7A)C;;*39sp zb{gmMppG%`C1wJc?<)@bQO*pE4-7Bw$uj zR;%1^p*;HM!gjo#Q7|2`&NiLuq?V}Rugb{^^nR@FrRlnJ8LXZytT=nji;0K$?1CL_ zV(H1wx-=^K&vy?~_>S)vNecpn?D<@U@MVx~Wu6b!zSBvnO4x3=u6$koE$xjtw&nO+ zTmSEU<7*;rw|PW@6zX?F4MDr$W}Rc@*ARSwk>$gahESNF zX@2Y#O?IR6i4T|;>4J>O$bwa;#dA#lh)@!72X>o}ex%w0Y9R(>d`tvD@ep*jOE%U;|IrfuVH>qW4JF=~Jn9 zwQItU03QZ}WSVvq*e+odW{+pRH_G=S(cck4vbqFF zN%c$zSvAR)c$1vEe%kc?`@@chjCzC+dr}-%`ZwVGERm*B*0y&8G-8*fenN zUOQ=?J~R-XBFKkE(|ZsAeOupOx3D%GFNAvTiF{i}M<)FTx!8M#TUC2ie2DrNd%rV- z5Aaw{zu^u$eKr&5on*jZho;1gy|l0rEz0BB73vs&nzvW4*%>!^byqIQYwd;@QZN}J%sS|UJc)h? zvZ8!|R9T-bdx=zcsMu?fE%YeLSr)GSRhg0xNARXI7UUB{@FRsqcU9gs{PljloyJmb zg|%Q^*;X@&`u&vKQ5#(Xc^J{!F3F`w@LiDXLAt!VWs_O{4c_9dt$4K>@y^C~x7(P3 zr)L|>>($ITLgZJ%?PtQ?Z7lIaU2*9<)OmUh&Y<7#2>qWpqz=zK;Unu2;_2Hu0%4nQ zk|QMSvUYOLa`zn<$?i|sd_Nc~QY1XF;t$zc=*zf8fYXWZfdBsP?(d}l!1B^B$#BHCl}^iQjK=hq<)#kY`+l|~ za%TRhm^i-TI#<{jx4-^mSK!wGf*salMk9s@O{Q>;dK7R9^(@#@X>nsYGM0>o(yi0h z#@CsforJ#t1^O1}oq|BsoY$)9<7V*R{I1xu$wBzmXvQb#u8E%5L0$Fj)IJ7MuW1`v z56W2@iQBu=aQ~Kx$1HO)VSnQPK8t;=ea5vWCemjUmcj?x>59W`5zFm3MDDZc+e`r# zpN0`(btn6I5!t7Y9u(Ied*d;wH)r%jODs??OR>qP!m${naJGyf)n;y8uPh?^vnDY- zOm2~q$_;3r0CUE~rWbWPmps;KUQ-E{gtX89GX7Kg;h;{UCBq(5C9#t64Nrs|i@d`K zxYgukm-6G_GP&Yse%Z%dQi3Jwh}?_Sq3_rFWADsAY|2j1G}2ATgeIx_O#_Pr)|2GF26wm7E0itWKO zSnvZnKuuk}=JeP=F>$Q<)&6!66i@EF#3=}!WFUdOIbJf=*ED~TzjgUyoH0`pUUFDo z1rmZS8YE`3g`kEdeKXfus>Pek9i^6bF17;5=FAW?kOV_mW&S;N}!mrI(NWx6R=_7z4jZ$p+Z0`QvZC;GM%VFKmUa}i7^x{Fresu?v z^=D&}@o~44=<}y4Xo%rITV94nUN%K(MllfMBzf<@oObs-@V+d$GQy7)kWS87=XY8w znm+gQJ$3{qslWb_xTylghP8oTyF)3~QlPjMcZXoXT?!O;cXxLv?oNsZmqJMK$@_gXbIn{c zzaS^iNzPe&@3rscaQQXj>7y?=%iF6AmAg>m`F`_T@izHT>x26=eZ4lclnh_QpDA9- z)>QCUJWCG+N}!wXKK4D1J4yG~SnrB=X1%EE;)`=xlxsKsD7`WwX+ETNmqaoQxYP!} zg2TbdbHWB?7c@+>7UX8MLL|(zm2l==2iGovIBInbGouH`8QeZ^TiiCo-ujFOv^4b5 zs1_*@lP~7x72Mqr{RIxCzLpBf-EU+5-6R&*J_48Is$)pSU{j)es5Ek z)8KfIos#wf5+D~)yrN~3LF-5l8lZ5W z&QRZ>@ptmPKz}?FPXab-Lfw&nxovbdb%mvsz_v+ml9ze+Um9h6uMc{@k}&_h=kk2M 
zH!%o2&(OI(b%aZif%{kRrFdC#FlEE{C;C-!4142kGY6jSUb_lo&3#{Y)H7Gp0z(M>IFD;H_)`SH}x<@|DpYb)A7 z5R?9s-B9z))CH_OrW5yHjh-F>m$Su&(aXL2u7_red{h7ue^v-Z`v_nAwsyEN+3P=> zyV-(KH4+^+0zIzz93jIBl%iL3U#GLzeTJIf)tq1I)r)M#WUXuBMxfKovJcg?P1b9? zZ^0;3d038#^I-ezOX==Q$1dPH<_ZLhcH zx-W-2L8toX2e;>RZ{j-X!*euCSQ+9%F3V9tfm!@{AP&Wx3+4G9A!f)i%=6zzd<#S- z=A;x*xXbNF17Po=B%OVNO;{Kt8BO~tpzx;mENI<2$fzu2#!PY>G@oAhCUBa+Om5Ps zQQ2};MR0N@iI3bleoe_2oCh;JUkD1(pj?yXromXM1@BVpT20)3%1(hscR3I2-RfFO zmACqi!hwj&?V>UoNvmGA+-8qJShs}2HA|u&_xdT{XDf2ID-1S0h`a+3aHXG+y9C;4 zQ-)_QQWK(4Ys=A;opM_?H#1%!)%UEMuBR&+KZ)PPO`9*8gv7bOi=ZUhr?@OZr=eft z7r^3wshSEHbKfEW?0q~J;LzUT&u3Pixbu$^BD_e})9rRgt1>0$*YiZL zhS<;Y17BZszqvt<>S;jCzJ4|d4MQ@7P)sF;;r=8rqm2Wry z!3)IYJ_(upKylel0v0Z6Ji4C*3rD%riX)%Fj0`>(qNPm#j3?1DHm@v_k3C*vyGHPRBvt6nvz&ivpcAKjTpMz)20Nz~vXS#<@NkQ_p83cyV@gT2`j zz;?#HLtxJh?U)k+s&+aAmc51fJLy^buYr7g|4s~gWmf;)uR_P;&v4xg*DskXK_tn| zuj?FTA|GE~DJXS5NR73Sek|Kx$OYb2^=;Hrun+99^n!BQG|(GN9{9sds58m`S5+2t z04V*xAw@Dsh9@Tc?e$3a!dAOC=>4ov@7MK@M-n#hMR^Yz@L!k78jOp>@gjGYl2t z{t(?vdl!FZ((3$T!F@UxxTt!i$XHPBMfD`-f#S`$>@ij{wi(Wj9Hhixrc`M(VA6=# zzCwrHp3OMb*E>`)gPUB)84^oDjRx`VYs9MB2IrjU+?;A-r{L2d;-l@kRoz@K^J(+j zTDP?otlBn|d>=Z`8Y`bv5B>aGF8`FURleF~8bZy>{c}@Ssq%Wohwqq}DwG;d zvhK8ir0C_5_v8z!ZxOU>)P(G}nwU#Ak0yEK6vx*~<-T2`+6j=8t6QC^#0z3b^f<~( zBjzo$l)Hrb|9e1ae{|9da46u)H0;mvc8b~>$U%_nl)tY_WIr~ASe zoaR8N#Dt*l&XWmydvMQJs9uV%PgRjzIu6$8w>fuJ;Y=zx#%@+?333oI=IyC{!?M*A zy|%Rs6VK}Ph%nqQ^YflueSAz7n77QOE&d-0tUXMFTA=@@K_eSqyZWUF2ng(-pH}AS z?l|WEx8DbWaF>PsCqlEfV=o}Y?KHU3^mw!)J%~L7M!um62-8Q50)oHl`5^F=SS3n2 zsQutD+2UAvX}1sB(_L#fU4kdf0%3-;s#Sz~bP`OI z-KgfWEt~(Pz!Wp0L2ps;P()QrX>54>M)SrklObiu~_L?BZ1$wa&)`=u6uJB>Y zB%fQ}S*s)^A9BM?w4AlJVx7NuT3@q1Jb28u#WpPynw_9gK!Ucq!QE&Uv}_7JMG&1j z4bU24?I-w&ARl(#2>8dCwWgGlyvnE6TjERom;Edevoo zZZ#_U%wWnAb5JSwe#bwd;s?J#8~>d3(%M3`+3UUvA4n-e>myJ@rVQ;t$=)x7V=*cB{S7JSbd+hl@_Mt_A6VgTvX`CYr48*bZQ0_uXV;f3sC+1f*L)a@YA zYeI5QPt|{Sm4rU5aEQAUxMCHpW@9du2tq5lG@Q{f+C`B{@|)v~smEXZD=h1|ElO4Hj*n%|oN8*IanH$p$fP;!`R%1^LVRwISTIuCjS$@Y-;Kh9~ 
zg>Htp%+ja3+MP1kHWhuZnL%q?7y z?8vZM`Ob8WjmM@U!0kF&QHbA3LU~B(X_0>Z&Eht5HXVEMi*Wm3=cKQZ(r-2Y+dGUh zeUqzyw@b%Wa#JmF!&~&P1#Re!d?IVIMgW+F5xiQkzIW863N@ z5rRh42gv1Yr6j(S^-4-E4@%D;UFu5T5NLF|KM;ZDn%V7Q=|vaO5QR#UB*p%071W*6 zHxW1~SUz2gdU6}F4j!@kywLW8J0ss5kf&}h7NGuwmPJb5)p=}sp0w;SKM8_@mP9`$3Y>U7^fvj-GHURg;cbm zP+SGe{mZKF$CO=0x3j`BvkO%>yJ|d$i14gGUG`s5W8wN<`eigf<)+F7gN;Cr5{6;s zZ{01Aq9Y1woODZWA;&$cFH55o5y-6(`@OtbHe~8o|3&CxL)1|UR+3eLJE=$tkoS-J zcAAUMD3mWlbBVl+e<%oal>*;1|MMz8dcx%3y|{(ny+5&4Sbx?JF}7(1z7V?y=9An3 z))OnZa|$bYtm#&41%euXwf5M3Yc}Pu;Du1MpL0CTFb2C_yZ|Yek7US6o0&X?D>(=T zwRoZy?+;3tMpDvfr$cOwmj<^g&hg*%ZDPEXPix4$9}UqZiwF)k>V`D*R0wY0I((zC z0;q5f6rr_U#MysU;D4;7CCTUb$%31zA0IKasJrup7f`6z13X-Jx$|1i=Zd1(ru97j z#QtgI1r6&nm0|FDN9#wY|B=T^ZG6F99Txle!D(N@FOyWyg%P@SGOtwGNKQMa?S)jg zUvvw+_H&yY1UMw(lCjFjmTM@z#k)>hcX2IB2JAry&3sFRtu4e~Z5P~L&i2`wor+}h z|8LcNS0>eqG!_q6m25bU0~5zLVOS*TywZWBAo!ZM1O56<)2&QJC_iZ^BH@IVQnyh8 znBZ1-T?E3xu7o~~oqrW%}-%aW(6lTApYd|M9dbpZYCCE;_VZPt2Eh=z(Z0KxORyFK-<{GnpKxLMzBQ3V$67LDy6y z>d9#8LvuWn;BkI6N$?uGR*!*5pMSH$v^y=s$6!%w_PPNSHDT+V_o-LwHe0$Jx#OCwEj&&5m^ieWmXj;PvQE<`z-xPkZiv3t(Vw2$tij zQ2&a=%SeffpH~SW)^85)bf^_z`-~i|rwZWb4gr#ZY;Hhpy;$%*o3C%xYdKHpCsQ)n zpZsppEdX$IOT13ki6`yQImw212$#UH^g-TlBXtsDp!~<@gytfm;haDa)yQj}S47(qUN|JSif|IHeAJS7d~e0=pg5 z+YH|v^c*4qPV7BLX+*k}5;LmM6IE}h@+56xHdHigDD-gGb1O0~Ty9E{Sv`L3&fcKU zlb|i@a{u*0A;q{Wd}@g~a%hk`MUJ|%9BnAZBJt0qsetu7810ZWBXlsr?eoc<|F1Gm z$}T-jSQmXE3t^6Y8m<-+nUQ1C=rH@o#z8=f7LWdma9KUm(*Aa1xWn`8^XO5KjRXG0 zVMpYzsD2!*2zC)l`PJI}`-{zXkUNTCT^e;9h5r-@LDU+#1BsCU7ZUV9!UA>4vGew^ zGmc(elG{8yP@JQF&8uE?GcI;msZ>{`M37slOx(Y&sJBhHAOjGhPesC9gRaiBwC1t) zTzeRL><=rff}WB^&n+aOg5>92`Xm6Y2(8?AW)&orYkmf7HS!LsLEI-C-@d_8ng=E| zIYc{-xa$5mJi$gHdx{0kx8mB<%_sZ0^4J}5#6KdneN6I=suyT?b&S7+PNxSR0)hW9 zkuwbh0gwjL6-v?O7*376U!kZSs5ClA8jMy2<<}x)j@v^$N229yY_gsTv|5GE*WO>4 zdDw0&yh#;ey6VQQ1-F5B)}&wmF(jdtpX&ecC{>LP@F&ap!Gr28^DCFG>XoLnjXC=s zJ3deFiRHu2-ZmqTnwqLA(*M=JwZ{BUJEQhKh}v(UB9ndiV6_iwi>E4p#I4M}@sS)K zPNsc!xxYw|GHc1m{|<;hgP&pM)M22m%E$Jg5-%)8=3lKVqUR+4UZI&7ZZtrci>T=m 
z#cX>rc21O|#*1PwJFOjkXEwSpW>r4*DKz#&v~w?o3(GolRk@Uk3dZH~azQ&axw0$F z>=d0Jt0wK}Cb@!Tkdnj;`AVANLRh8F%P}TaTL)&&C-U?y7)T;`U zue25i0vqR_rDF^u>OOk-6fAyyY#|k+CW-t4tG^3=I78sKgC9*$Jh$<41-7KjLMaC4 zs&HAqChBX&hKD0>+1W_@uVYz!Igl_C@`h(9{CjroXBLt(*qBu7;ws5&kB#u_!zvAn z9vqEuln)PmDdC@u7QLmFmHg`J_unhHH6|<~zbMdtVu%>VLg27tukTp>T0kDz)%Jf*j5;6IUD~R7tSY!Dbs+S6c+dIpUu;qJ{LLn(;s%s z;%yhiB`=CDWXujzY98lOJ{;2+(x9GNx_5RL2#!oqxQw_mrLj_90$4uTv1%^Og}t)y zdzSY!pRDvs0)h0fX@Z?;WA;nw_(1*32|{KJZ$AF*00w~!VP2G7NlCjK82b_yDNtSf zt|G}SH2g2kw_)0fN)%q(n6y_VvxpydcrP%95li~G5M_RiKm^$=%+4}tHFnGpMiFja zV+31gox6pJVi~us`03Dx0jd;BdZ|$i>2#^E&}IrPE$zL_?Jn!88hKL*Adi><8i6!v za(p2mL?erN>g|L=QR?^O7)pE6K{)WQ9(QZ?eqjQl_kEP^MSVH(b+-VrADsN#d#QBY z_}pDpq3qsH(UXzMV~*(X>;>ZAsGcz+HKx3t2>z1VB3LhIFVpU3+NjdH%!rlN#YYFU zb(QGf_`}9$XXS0PqlGh48(bnSl7I(FGj_v?Ge{9nlyk~^%LA6!$oTk}cki0ce3986 z#M8t#(X&hmKj=rpc%ndJBJ2iw{ECcKO>OO%&Q7wFl=pOW%4ik3VWG@~DeT)615&qe ztHrm&r)y!C=98o()<30sc|lQx3e)JhtHaVW;HF5zN*Km}+ff!A0TEpo~aLC z@ujakz16Pet+;!$aj6k>j+-lt!x!fObE)Rf#fyxi<9k#*V_%#PI?QIo(A@ExKD%Yp z>j3Tr^VGypLnPy0L9x$%@0V8!Yc4+OWwL7K<^Y9K-_z)z;i|}g-=ro|b>GXo+eSzOsYYSa5NywxYW}z2#3J>+P%kJt z&F*c6#DDO~p-D!sjblUh`&1hKPg5(nrWWj%SSO|e1j7Ymk5wXaO5^LS!8|V_ony^Y z>xa;6xBr|0a0oANpFTa`rvVQnNf8V8qA(AfVr~gtwLb1De-|*5JFrit7U{rD6|{T4 zppf|T?HllMrHsUwYf##ge27F(Wg)-#K`AlxsDS?%^%q2JvV82I2U86PMz>g@c8BZ3 zSwOh+4lH|zc8LnK@fxeI5d;{j2*}g@3g2)M7eL3y!zu7Uvsr0Q-Jou^V<2T}@U$H^~glGd3F6E;gi)ajO+IRQ*KS8NPzB z**;z&40%_fNky(kYm2-kK#_&;*n%d;oXSg%tVDrR?{oGN3E_E42F4drz0aFfN?0-X=|*^fB`W}FS+RT=((X8xRtf;l zAt9L4B7To+!SF~AM1v&d_&J$TTaB(MtK0o8c0S{u<3kK!nF%CvzuzbFk}!w-psY`S z>BE_kI19DemCB0EgIEaP z`ofd??fUH0H^=-iGL`Rkrb45!M-rWYjo7dp##XGRgx=3?Nv1z2pek;-EDYiqC_bM> zqU(&uE=n1KWsl^w{zmF81$q0aXV9V!&4;M`r z{ld$@B?6Liva8tApci#OWle>|(vNni0IA1T?*f&{4g=+4b7Z`|aIqsvx%qFO#sq56 zV7y9PK`kp0Y2<(uM~LoPDXP~FlzYg2(X^g5lk&9S1-pb(jnywmT@+BCH;-hqZ8R{M z8Jw)eJJAL*H&oHcd?KqN^_^zZntKgj|816(>(y~pSXqhhFe^D=asaEgwka-uWW$n4 zeZ2bB!G8(ifqqXXv#1xmWg?BZ#Jl0Z8S-RA!xN&L(w3k_7bKbCRLTBh*0TK#iY{Gm zAA|`!_crS5bVNuVEw 
zFd8gwQVEaBmGI5u5WP0sSlU_%H58xF%#C&K_FkCmXCXB<(>s|+yICDmv9EPsxw9Oc z)3vW3clcpk+&5<$o~yh4#nzm|JxI^%omaEz>fWOB2Qs$5NhC#~p-@p78JYf~MFlac zL77@w1A19^dn!i(h^f8um|9BKw!it3&GorR3qM7{0*Qs!XI`fKt0pmj8#mkQN>-## z(j?xyrl_3$a5Xz5fAsP|nvpXV#N{(uKy`n7qF^?>Z{9um;S+neGgFqxh&E}KbF1zI z4>v9vpN8=8|B8BxNyzz0*qNux#jmv+Xl?n%?}j;)Ze{j7P1=V`Hvtt#Xi(9xhQ9g0 zM}L?ZXmL|}QYy%@Rr~_$wrY)Cf9YO0eW~M}k6NZIFGJBtel~7*Jw$(HE5J-p5d71u z50lLna5xL0e>782<>~cCDeCpb0-O3Tw}ze78TJCZ2IJ-vokQnIAe0FbFqlBFvk!aa z(-Ql6*--oVg0E6wUocyrissTivWkLe^<@A}#MEiBS!>0?%>%BLr)kt^tFb>zWe{Ey zXI3K11Eo&wGxkFZc5Dqv#vOpdT-fBZG%jZ+t`~Rx9M%YvY9dP+39;iP)|KzpSJl`oKATT=G z=R_$kbH->0g}F_8yl;Y8W_n-nl%|Lfme>+Bof%T5=ewP$Wa@lR@OF1myIhe9`i$Y! z*=5~jcz(uIUSC6mh*(CFrga2wI{KDJ?_!sC;kLRoBb%)WL!a5~|BUX(d7$hIGy2gvdnK>lP{HqdR+2%> z5nIF}@xA~NOR+S6w+4B}?GivrLl@!s1Ha~b-|p{a+-Itc5PgJ6;$xyq(7LC|mci%~ z3XB5rNhU=mXf6?iOEnXsOD&~qjarL792q$wr7X{6q;AjEzjvN_fajW6XZqJgJeq(2 zj}A=7e&F@|+NBRdBbK55O50LQs=jHPl*c@n%GO6yc*;Pu#Hc%SY~=2HMnJVpL`S2C z#3D)2_wp(lfP*0`#QB~Eo+`8v)nes7zv+hT%Qskn@1tm;t+A>6z8fd6{LP}dPxA*= z`H9GcgdqCtT5y9BvtI!0u47Qg#7fIJa`=L#%(6%eUpB1;Y;R6%_|1y6647N`4o|QC zi65}ON>%;<`cvpNYtV*$Ra$B+`MCXF0~cpF1J?r)4AaaD3m0 zs9Cg;(B6Jm1WbZy#OM)sWmpQ$+%(t!!PCV!zN{=txWczGxPeNgfW^N9tUc|45Xv&_ zt1MSMI!tNnal&DN-l5W8r@Px9EOjxn(oZ=Q-9zE=C04;YwMHXv(JUh&IA?1NlymkW zUwI=dQ&26pa4_P=w$ZRK8I29oc)~!Cq~KOk^L3@2M;$aM)efn%|p0H zfUfBVTOFGsX7=EPM_n;ImyEO#8osI zWGrF5d8=8s9a-3?9S`$sg(k`onVW38s2OLhD_h(?qKZ_-B;pgGW#|S~^({iB+|HaZ zkAKQDQa2TnavFIRB*K>e8bkfuRrS@5T;b_|p^zm}LPTZij0T*(JEZH)^N`HFb!;`E zwKkQP-8?#yx|$Uh{0!e^jSFLg{qgtdLYac)2t&3ijk+BtWD#?Gv2henUozaFrNGdm z;|u8YxY{2RYExH?(Pfd=_-q4nZrP^4a4ku1PR~T6D0yS%^E;?S7HS};Xw<*mRP0H~ zgJaZg%KY;wHW<-8XBDfX4}C(g?QLkF4i9X-9524 zie8#YsiPDarUeE6IK(Jmh#nP8`Qxki;#}xx-+jt$&xW2Ef$1ynK#uPUt{ecAPUL3J z{Pgd2!nm2OAKn@EJArhjxK38CyT2d>kkm+`TWnE=On6RM87qv>ml^g~9Z`fD_55CO zmY(6joN@(tR{0a(_8>(jyPh>F(aR5sa)8~rPwN+SCi*+C>$$GY+~$PZJqIK1Di!gkZ`_kWO=SMtwLH4qkFEX?o!PmR98}>8ewm@9H)krnkS?{UnUp*SYXzLN)fKw((yU$^N)~Sz0Bz8C{r;BS|==-B0*$tL94L)RYU~;hrc6UsK_ey 
z&-F4~b4N#(QXu$s5V7JV&8=6P6V08dABGs5YbQ~L+Sn1nqKE?33Y}=~Ek~#!18DG@ z0alq0!+RMUL^C(XLg+T{(x{@89EzE@+1Bx<(3?Jt4(}js?0tT`z!u_o<{14F*G0F{h|D-{OOkiv=o#VL#tjs8;awQ#^{_}g9cKmNK zP`39pD_(u_uhb~b7|GuAy1?#sU1oYy*p~nKt{+vc5teki7I3XyOlv0~YF?h>ycOo{ z-N-s!o;?28{(EPdP@d;g-c%k)A_hT}z#;Nm6rS2uoN>Hhcz*za+L(WMR#e|E>xBcp z)84}{W5_4JItFesunTAE?2%yjsoc^QU5CKB=#ynYQJH47~3=?|p2xPN!VA5W0i zU!j|yVK0j;+Dwj}%SXi^?pA;C?=-mCk#pC^hbTqICayd9xKmJf`@YdFJ#Mb@$(e(l zVi#!tgIMW!8HBeggcz5*2{>yjoBK&J^UA2f05e|zhuacft92=0klPX|2FAcuB=ev7 zRTn;Y{6F$e+dJaCdFwpR+SPuwN{klmNB^+?Ck0?!aA|9JE`$%R#|;e&3nosBDD^r6 zu6|FBT4#I1_B=E}BVr|%SHZfB2ynRRmntTInjz1g+vg)J z{Qrz&8;|Pf?EH-^mV9xSrGvO!RasU+o2W6}f|mCd|7U94#_NytV%x!PS@+mX_3S10cw zjbROA5T(>`8o9*yqgNYi-AK8&vsqN$^Q?mJFQE-U$L1x%f_i3zlarIr(b4Fdt1EGW zR}!YG4wko()#O0u8OacY4-}EIu`^vT(zTE!zd$uI z&q<5mM$O($j4LT>Ym#n6@Yl~H&(sGOD7Gb~{fu5p@#7OG9jzRIZ_*)OMaj}gOpBYU zWc<#EhLPtMeZ65sRrvPtk@Is2G#Y$ArVaoun>#!$uN%-}Pz(+d1nK$e#f#`p{ zA|k(IZ+kuNP;2a(174X1q~4wa-<&Lui~ifAf*UCwTiLMVD=o>ZQnawJENpKl!E&ha zzBfg8o6KUD6so}h*Ss;Aan(O$P%CdQz>hxqx^xww3U@S1SW%19A$%_>VUOE8UL$LD z*FJI=d+TM?omQxx@spdqAZMPMmeaOG=2?h#sR7`V)p_WbWsD3hxZlT~|NP`0x8S=6 z=)1yACf!&T7OPN|*3%>VqH;R_lB9+X+p6*%HVMiYq;slj;1LV8qJ&(w&?oCdZ(Ut# z8jxm@85`QCN>W9der4k!h>Blu*_%BK0mOPduozZPa(+FmwlQ#FIrU$zDm)A&~m`j8{1h_mq9$Dr+@^ydW(7 zDyHS0)vyXQ9bqz{s%qK3;n)_m+*~Jf%18gJt zj9HoofV@i;E=DP)E!ssZ)oCg7eP?LCHn>(rwe`_0koY8HoYj8#1&18=em4Y7#w)9} z2P{-`nC(Giw;;<-7#NyfZ3J~YGv*G-iwLvNMs-8yi=PAncp5P@1FjHSnc3T%L^B+# zM*a;U8_X)p`s$H6x+h52K>FhEuzP&boP4ufnj#;?2r_Ke$14Ity>9i|u&sfryutUs z@hzI1+A!5E+w!4+S|E_b9e$y?HF#anzdf)F^8v}I)6oUb_7b|cJtl8&|GnL5+>0G| ztv-p;5n1%$nE~q`TAw@=t|AEDTuf{zyS_D`ncoHpRln0 z??)DQl{;xr3c zz~86=cdIlaku-E%Q`N{<1XdcIkghY^d&j9_Ocz)fd^I<1r}Xe)QJVQUaM7Ef`vX8w z2VTfj&@8&!T+<~RtID-ZNlR-2F9_x)<%VPji)>x`0Fyd|%RQQFVfT|kiX8s;J+eZl zBQ-wG+ut)e>kBZ2 zo9|1EWT9(h&Bf$Cxh0#Ucl)lFwR8W^~msk4uF z0Y!1F)pUUq1w-Q@dhe~u=y>U{glj$&%;yFaP0cK)%&d2yPtM=j(p=y6Doznu{~DeD zniIU#%tsGu4*{F^K4zyKOYCZXqssbvTd99@KO8;;7g03pNW&>iD9mtp`~qvOw_%Oh 
z8jWo}32LUNhTr|v6MJYz?x_nPkg^6T*yR47T~`74#sb95*(o9p?b5u0qhUwDgD^X< zC<+9DEn5xN&9qF9%StD|7z;Vld;gxurBfOC!i{x)085F9dryzchkc#(el zujH9Wwp8reTALXNNBy{bmR))vlq|=KZR-N&lhA7Q+~Gb50@-bmJDMs;&E>1X@yr5{s zF@BD|!3(>GDN8NSE{sKeN6J-;Aj=bkOQJGbezp%kpn0Z`@RZR14csY!VYOZ*U~Q^d zX-tq$2tj6-&UI-?=JYo3|JSjrHg$OTg9H=V6Mssld4r?6uG-|V>j&YP{su2}FSZuh znm?B~lN(U#Xs_1St_=S1 zii0+}uWcBBH(EM_mtl`m54c(k`r9jK#-`$P=hFmI!ud zY)lt=-}Q&W6(*_aY*&3e7-gwfB`U8(1X@2j=9%ipweX)}bOs=h^~_vl63pAycO3OX zH;`LfJ-Ch$Aly+zryjL=Iy$TM9jg8(5U&vIP5mt0Tk=3+j5yKL7*#Cstjm(2){U0{ zZS97F((Wazv%>9-KLLt;#yU?gFHD9P7zwa{!`Dj%@z%DLhD%5-^WON; zFJxZ?(%x~?y*GVv2Py%b4zRuBz%l=Qg<PkhxrcwqGVs#H;Ol=5*O)6_YwluH zmjA3nAN5g@H+FXJ$yL(uGL!$4V>}jEI5@oKzy7z!Sd;f=!!_4rHM<7*dy&vz7;;JO zi`aHlD_b)!B@(+_zu}uFy zMwWG})6m%%58>Ee_c3m}X!%tGs9|~wJ{Gg1s)3woG(ougJ%FLv^n6Cj#wK=Xt0>JY zH65$?JcY;xDt=)rlssyc9Uf6<2n)$eb8}HnO-RVwL;@??4c(2wEm&CYB97ZPJ;bC` zs*eu#wRLX(@>gM}%^f2p?S&0;o#^_!X$`F%SQkEzNdx`zupUDGy6%|BBG#72=34NI z$sf!h@NqMfOnR5YKkRos?Q11VFSHu_E;1`Ge3di|lyD zu1^{V>qO0QJJRqBfJJ~{hakVal`m1&Ffp2pylHfmDfKkX?<=sKK>qd${VSH0Pi$T-w=kETrWKj9aSnk#EA1=xT(ku8 zOO>q;i##2*-`gld?*YH+5<`H6I`e=6r#XPyeEVWu`j%S?pyx{M!TCF z`IF6Pd-#H^GhT<-SDOuaq(~*$JcKwi5{#ruxl=Y(m*0@|wr z8B{WMy};f9$RZ@xJ(zcP{-mMjWM$p@W48uQzLU+%Fi}%)$Ml_a-`QFV)n-=w$*G0t zU%y4FP4bgi;+jpwL_K!IgXG$?AL_+d7MtRAWLQ(RLr3SSgqG@ewaI-scA`Na?_LIv zJ|!4+w8HUk5@#c7jWa)J)S(o)m{D=>>||z`}%)+ zTS_1teIhxL^XQYmaW~k0!lss%1jts)#50Ok2NqW&>$HAgL4gn1O`P*N|A@9kCIs!T z==F^xl2hHF#&DrFUr#Y41$O5SaVYS5%Jn~-?Rc7PRj!nXC;D9`&yA(e%+UqLCt!RAmOXNrNSUa z>CTmK05%h-U)X%7Nx}8X_F>r4wbe*}&ADzl!2Auq;LK(k@-Nlhm*WgW$EPE2=aB7y z;JWIpkJ>)(#K@~6kdq=A8Npp-Ix1eJl}Lno_n~XPJZ!@OXDz=c7b*dKypG!>49yxkh84-d07zAuOf+!#__Wq^ zXLSY!#I3gIm7aIhF_g?268G6<9H(=yef8p@{%Y*|I`atm4Dw6R2xU^2oGS&6GC&29 z7c?Y&m$NXzOnv+dApf&dLnhat&N_bQRo~U8bI2XV%=}PJra0lz+}#rQ+|>!-=2Rs| zNAk;dg;RT>l!;*EISu}xF37iS%8^#xE}|3-Q7e1SqQ{I1SB<9=?~ zCFj&VJ-<>If2gUW6ANE?A?%K~?PC=E2zsPC1hZ5--IA4=`EPdYOfcmD{9oMcHVy5{ zaub;PT4$$11a8kRkY;_J&T^i}yk4H3z$=}GZTz5_Rhid}gu0rJ3uQaX)%}K9@zUq) 
z%g(<0v$Cf~w}b{a@r@_D8fKi|*o~5NuSPA{KRx|wF6!DK-o9^sU~mC{{c49LyKfEQ zQs?Jh$2hP}ZM=Ul6r7xS12Vr&5WP|fy^wp2TQ__{GtJs;@VEE(&#}9@UtRSuZit7* z|LZaF6tu2?x+ypw$l|jjnERbNOvDM4BSKc;f{fvkSq5zQBOZGuv)r-e;+UnH{ zHo{@Z+7FZt%InJ-aZD|5c1P(4FikDOW4XEhhVP5}mKs%j40QEbgt~fn@ZJ;L(6#`* zco>rm6m!}2G<96-c?no}NRQsJ+6M%#`>(J%+EvB7EK7tpIa7B2v73bAIkhQ1bo>rH zTIa$oN{ImeRc`$ExcRU#ooBgm;|5Z%*}o8Vs_V(oZRfBvtUcS38N0X>Yjf5<`>H5F zc?UkhlKdNex1iH2(6g(4uTHLZ0^J&XCbJ5<(3jjS?x?d1@{hXr?SNnqUOZ)yuSBJ7 zegkxEOC&U1+f{9p+3P%=*m#JSm6H?bi+zZZ4mAKjLTWqq7PB3lPtE-y1WnieA=m2-Tl6mb0SiJB% zKD9TRQjbbM3!lf1v^Z?GuC&tRoyBg({oSqsJ#-P}1huT2Z2H@-hTiHNCkUWS}r zEmse3vvz-{vcVXo`c%lj134vnf1RcfTkP{-y04i=v|!JSjjes`bDBBj(s-2^0n7fR z=U@+n--tqa33y!p0C4b~_VU+xyu5myxCZCB1<0X!rhJKuOQrbK62zz`f$4it)|%3B<`? zU*tS}Q7@3F;;i$h-ph;Av%>0NQIq$BK|s-D;KvCrT#rN3{FOX^R&VidsW{;YlU#EJ z5mXx@>4YE5yPkV9K$2DtB-RPy_tyS-d$%;oKZ!gyGOpCeuHGaU)L$a5K5+()R4r5) zLn`ln>-5t0>=D)})%ily%i3y$IdZg}#cf!ApMg4uq`slFiSnXqRf}V$K->1AGPaf7 zpkHrOO{UZeEI0^<=HUb`j=Kj0Ub*^0)_pw-n}J}}ld;|_!jr=~Ux&vEke`*F&y>x{ z)Z*ClS)xevS#9B-M`_1%FLW&Mc`O{L>*_`6@4h-~*7ZD5$zCglnbL;!1MFKF%}c5k z!z&pMB*Qh{@;FRRCl|1{?n%FO3XXFq%gt1iaFiAK6Zov_dvO-P)pP!XQ|0XjQm6iQ z5!aW}3uX%7?>VcFJ6YmYS1HRja5hw~`3-Yu_Ki$}|IYZmx+#YP%h~ee#sj;=|NEKY zgoDG@g)k9nbQ=mGj|mA4h5?|dt2`7+ca84HrQfhbgsWbF-7ZBQj`GjTh@IV{f|I zjGeH|zim7(X$Jzpiqn~$6{OJ0=|3mep_^iu3?U2jg;%sEKnoC!gVg?>9uFaB+lmdDc zVEHF5OwO+Cn7D7ho-^UWmU!?K_B(Tj5nM2qSp)uTZnDx3zRb(fV^-WWem9bt1+j5Ni*5Z%0zS+VHt zT(7NILe>%qIYQmbXB|B4m`sj-K#i_!d?7$=d?8)0E%-obrg9F^*;nFjqcpnu(!pt< zusN;XFT`f0#6HA*CE9*{cH++3*vv#|s?jA`g7;VXG#Mr41Te?rsiI+JsZ#wWW9ZCDfrTp zLvFWQ>8v8hePEvn+##!ZTNH1*{|2UyOyUWQIg4>D*|mVb8qmA=zu~ zr4~D>rtg_uj#_Z{;xHHX)vno0l;ONszI7(e8K$pnIyakR9%?;T5+Qsf+$F%UrHaUym=q+;2ql5%M zoStL5;MQL?0@nW?({cZh-jgF`sQMs1ipkW(ef1t4Yt*R7<%KY5HJUfX*oN%c zLsHg{ZuNF{vFTll!)By+N%6_6Q%uc&Mp|QKJk=&@D0t5ANsCtK@_p5{J9^vMRC`W$ zJ}^ec9DamQ4EVpiejf?&zVh-krW$r%QqlE>aMUEcdTzhYI8oX<`*NA%^9;0JqS#Xl z`FE;wnApf4qKMw7>LFZq5((LQGF&HcMAi>hgkt!$6wI6FZg4#1v}n{$VfY#trS zCJwVJlJ_tF5)Su;+b)k9EsQYpm33t=1nN 
za%+{|Td}^zE<>~aAw+pSxGi$^gZ8X{v%2+A=6~#n9OJbEP}1P19pn$tUJv-Dry-j8 zL2kBe#D4U<>pSsu>|$>wGwN!o>!)#f_0pmIN%3_>mE+CvFFck%htN7Qv6<=J|6ft8 z?|phOR-@>Yaj|24c!c{^>9RKou<}TAM(k%QgI$-SkB7|ZNs+^la|M@BFLv4kGuU1f zY#f|@r_WGrGqS>zP_SL(yi3+Ro8dz8uM-C%b{nNg)wtM`chjXNts1+BR&CNzZEu1MX6S0p}ybU z@h1GQIeIOH@_7ANZOAAIbL~~g+Ok`RCQA{EhKaHo!Pucck2?B?X-5rLj3w!Vu*q&8 z)5000Il8a^(LNY7i$|?5cQikjQ8^NF{q~x!9Q9=OvLTjIRoUZEXjD4`e)#icW77m6 zP+VX*eFak&T79-FPTyEA}c9NrqeJbqED z4|)1z#F%KBH)*OISLKhl%EQ{K5tMBwfpg0*vd?`Vf_j+FSomyIhI)L;>mc6%>%S=< z2wCUk36lcG5YkW!;Wk2U1piPksO{5rF)>Lz_OEEKsmkE0MKI88v<1mo`26|VF&o+_ z?=J)1&60VPc91Gs2kpZvJ?HnDqfZla@-Gjq^pw#)9w&ka7)nus3R+34CEp15m;86Y z@m=3mn#cb!4NZ&YT#N)ki^Dpz_ehUbC79yuuBXwBt8p>0e;#gSZ z#0G9g(!#|rUXZsp)*3mqwpjMrqIRXbXZe*5tD0uO*G+4Sbxh`K5WA!0S<6}U?e}gj z8B~PfjLv%nQK_Fm)TH)xEk)HXMdgP)Us&{|g{;QfMa8t!M%06v)d#x*ecE_;?a=tH z-G56vJJ>;e{YcYYeyOj0=w^zA>+;b2#ph#hpOAt_6!IE-&%8f}n#Ilr*=X@EkG-Bh zF~(|LeM!DE7~o3o`w~Qpkp0w*S}g1-uvqi|NSq}B6KEu_AvzXp=Zq4~&)bYx!Q(ND3px6u5vwQxrcAkfI&+Qjdw$oBl z2ev2$b`kvpAP91~*o#W#T!ca!<0-b@lA^`O!x$nV*Xlc_Uk;BrCUSEj=-^JV{eBVv z`n|v=9_Zooav_EBKl~*fG(=X(Gm2=@3_CcYZtg4~4*mMzIX?tBoL)Gw^_!|C4!)kPM9>ZU%(K5ALqQ;$oHy8=N`$=4s%oxqCw zIBV78{9fp8K!=U)`0jTMi4>I@2#0 zuOPAS?gk!#dpB=_#I-OzH~6+WtTqyC*cNS-a4EVqS;xE3af}DXwOj3j%Fq@eGhRqB z=?|u*1p`^P6rojy3x9#xK;RrvXaQNi9fA=1h>JVaD`n7PP-vGNyZG$%{?_{pi^2|P zt*93wuutVtX&`Lng`nh{L!CS~q0n>_SNy2(n&Zi*91XUb%e5Z}Mnnm2-vL3Py;uFT zk-zQS3OVfRw#rnC*%6mI<=WSAvRJ--Rh)0r4ej2PgJMbmvd%ZVPigAC%ucB8SK^Y- z**t+Ab-zVu)&D)#OIUF_@?+VoGvjn^B>*jw^Hp|A&m_dLHaq9Mq59Qi3K>G9U%<;L6}%XC1|mm+alqT4hlEM zKxfcWVBHVbN~FJf@9cx=-+gqS$PF|?PTj>Tk1vii8COdiaeVdz@IjP-x(GnxGx!A~sLo$`>*miS+E_nf-6>;(jl4w5 z-@nEHO*7aQrVT;Ed6WLtB@zbriwGF4;UDVbpe?!|UV{{?ldRfI-vkw$q4nj;5X3i` zCf{0bt59>`BWbb_tuz*=|5U?i+mPc{=UzwGdw1)|@|~ zevP0b1BD*G>i4R~se063(;y;}5D@aqqZeXa%@jy>cyhq<~0V-D>5kVM3iV_-SA}h8LH^>Ml)?U0dDGs2|o#0wY9ap0SGeSgLN`aHorSElERZ=+Ka+-&U`)0br)I2%zaze&V$YC{N>eeOhG5% zKiDFVVwf+7=MV>J9km`jBHZH=NZxG9e@Xl*;N%Ie|Cuw}A6teC+bdGWgKunor}zV^ 
z;m<$uLI8--(=797=+`ZfxG>+{&@3A1&eEUQ@ts`7dYaX@yy0#q%=RBeKhlu_@FqTz zV+~qu2}Se1;k=HoH<)w2_t*Lu5FhH6vH~_`V>`dX7AT%Q{~-qW+yw=7BUkaQq8qk< z8SJbJm|eGY0@w?pqQ!9WPVV3PWA6hDw=7O4mu;+gpD6i=8^}cMrNgwv&RD`SB*xT_ z32b^h6-fU1ER$- zaV}LwSk%}=S|AK*+`9=lZ6)xr9vJ;w?Cn##3#YwFBN%Jjs>Zp@KN{$IiTIX9wg9gl z^Q6Eby#w0?U|@-kmPHCuc`FyV+p0rNNGbc)jn0B(GbMn%NMzJ6>IVSWlQ#?9YTm!q zUI8gpa87Yj@w2$6QPgli>ES)BUm?GWzjT$X_lWHQVuXEnpgH^I^xOyy!vk3cVtRaf zZfpC!jY2PsmSe{ifpbm1+A9#LJ));PBxU-=!iAYq2lc8(w(K7pmF=R zDc4t4 zLE<5p9^W1x?aU&cv()WSvX`^^VNo%l%2An9f03}KfGr#QwG9LcjUV=uMMgc}0%eJa ziTGUt5Taa<_Tl;Zl!a|=WA0dJ)=laz5Dtxrld2-L5#WDrB*EG?v6JGZ+enAZUXiMR zc|pl$y*#e`FTm#)NVjU+P!dqxx2!WShR~m@)zCYVZ?PTicS7Sfw#JsWU>C(C6yM!f zDj-pCa+aM_@9~YP;Qg~@QP4vj3av4le#MP}oewqog;{R{xgMEpF6mCN|E%_6$AU=N z-o+P4KK=a+M~ zQg@F{L7{;}KQk`>#0|RFEG}~WnRd4l)tbt~k^dS#1(G@_IkWuHc$=33SW2ZHHq-?< zwhw9=U>%tK)$0vuEA+*bX|tW{vlmFT(3$qtvwSe4i<^pTcvnnH?Fuz{pqmOdW(f$` zYoZ%*&v~l=Hd*Zw9VS$X(pH zajmDF8*ZKzY5o1V5IpuSsXa7TdeGAdd3#U@%F1*&&{a`P3snI=2AX^&Q#1s24ww`O zwi)ILlczR)Ed{7OJ$Cdqmq0U+CO`;0F(w;)TrW>;|5^h+wls#B(+V@km$>vy|4NYEAp~-5)?VlC`)RXww_99D_+3p3pwnQmR&Z3erw5YCL}&a1sjrlgpJdBwbT> z?Cz0xj^0^zJp8$%j5ZaoM6btQ$+4 zQAZ%-eA;Yu-(A3(qsHD2Jv!DJZh%<#mcBo|6*{^{Y=`J>faO4o(7V1^K&X^CnEpf&-t#=P z9qbXg6*3jH2^t~#-JGj!rUSUP#3J4Tp(daS&ynu6fUzbIs7VFjeF!PYICoOjW*By9 zE?4kzI`jP31vddnmlK^vd$4uavseCHgIx)p_U^5aARy{Em#F(XG7hiIZpE!;10Q19 zeZDYO4CGlZRzTqv@j!sKji8yQvO{KP!Ht$u{nCq&wvltx18#m-rstx?k*VOCJ4Vg#s6Nm$ouSazfkWH9wL+VzWKM3uHfomUIPhKCJ#pwJiB;Usp^Y`?4*i34;mD*rnBn+8m= z!`}{LFn{&10-MzS`UOH9p3LQuRN#Mhxfk)sLl`p7y2Y~*OcJh{voCJw(yv?S)%m~) ziRusXM^n@M!tX+%XZP?|N&W?vayi>r^O$o8=##nMlWM${;C99I@+(EyO^|$}?m0{C zgd+ebgBvq?&0HnN&!`oAq4{MO`&&VwEs}VyM!A`m*aLCHVdbi;gxv}rVP7Y)ML&;0A=odyBZ%iRrf@J>QzGe1p<+pZ~Io@r>{MgG%oQ1 zHZnj4|%3?_t{M+O3QIK1>bRprma4lduTF3s0%%=l zz)HG1OdEoV^V$OuXnu!mH3Eq0hmVUZ!oI`Je5@fBdjKZd@8p2>PH|UfcJZ=s2Z$x1 zd#{xY3&XW4BD+QC74iX)Cx;kwzk{>|XtI->Cy9%*i*4`Q-_ZrW0*K}s19JRPr ztq3r3gQ8_lU;gy9Dj)}krpCrUB&7it6U`(Mf^2PA-!sf?*?W9-$U=4HS>iX|a5h1H 
z*_LM#eg#53lBgpCK}ATyZa2-X4~vU@0&x}6z!}}*fBz1!DB(AtV|ITv>;Pd*Uj6Q1 zJL7kUezi<~=!>Y>cE`W}aySF)j@?C{BK|9(ue2;>m{?UemnWT*G<4!_1v&#r6<5wBil*nk}GlOi9aPn@dD z3QGf`NN?}!tlXJ!$eA~$)D;oPxCf`9DhfPxPqKZAnt85*a|B~oj~jMd4HC;Gw^LOJ zxA_i=8a)7~k9x(sA;k!!Y|yLkJWzt*@;!fB{-)gRclAt5JV2I~DE1#HBrJ8`NO~Qd zmo(Wm*352U%xTZlw(SDBsoS+pp}OYY2wzvO+c}44zSc_THApF%eu+g}?q}S)vGkpcovR3nJ75PnYYd%370ID66z)w^0? zM+J2kLIP~A;RQmUuu^`2ATLjWI2_-0HkLS;FpyAO8trq3fcm>%BV^Ch0ziV5;P3s> zU3_b(3EgSu`(|K+Z4V-azxCX&GdoJ{S#}n;)c9#Vg#3KS;#+w8?N9OhXDy|XdO`-| zBJ6dl+W&_AWj;*Xj$b{B8?n>@BE^S1ho-jIe|qrK_8D+oOuqCHLxU5=IHJS8GT?+9 ze1Bq4vfe49?{5s}0C%8-W6mMRGwp!#rxCwuonG1+bu?*b#27d>bD%YzB00)HiC4Xo zKkFDlp?gO=!gz`^ZdJJk@kc5A3}x*1NSW)>o4!~nAWsMxye#q{ct|Pc`=33OqzW8P zgcw(l;@5h4+PT%-UQ^e+@w79DXjggv%9x;6YnuM)ge)>Y=%@3~@ktxQXmJc73B zAN*;h$fuHI0sfOi`$^luQR@SGJ#5<MZ^x{Si@<#8AIB<|xoHr`-Nb5dt+tWyoVc;(6Xg}+D8*Dfj^H)=zq`e>v zC}-Yc->BNEFZ&asE8?3g>)ut&}T)$ z18_0oqT4U!H~QTDe}h)FplUQ=MGv*6UnB(cv=NE+r63fF8gUh}uI_l{c*+)1URez@ zkMXPGlhx|?9)skHf>kes8#(_jt^qcx&|b5#^X=TaE~^9w&hFZZBB#@^?v`h{B0kGI zr#G>rp{luGnBNQ8${Ko#W|{ha(nM!WkSgqThwi5bYovaWdi37jel@YR<+;}!^GPFH zoH;Kwt3c_Xh>+ATEz{k#o^$$HSq;q;&;3$1T1^v}kFC>rV_WJWYh;#Bn#z@O; zP~9d>Pb}YEkn)!1y@|e;zF{XShfp0?2%U;Z$I1kj@cp}if`TK@E_XS3%~kB>Uk;eP zW1M2K>_=l^b~5~^Wt#Ggn%f)!a0GF>`>J5tQ;uTH9J@2kJJ_!HjEiH{M;f@|UWO8O zIrGa(W^IOKh%3da^Lu~gcR>$|ir$i*cB&WO5V6E3`guRoNW+8}8jd!-Sjj2kxtEUI zWxufh%4fY1G^x;W-=bP%8KuGFSp+Md?z&YLeMf{j?;mcyP&s?6KblTYY&2O|VOBdC zGfGsXA!S(n5Jz6UVgMxZ4x)`7PG384K4gsof36MR3m3t6ww*c!e#TGrbPIa|LuK5q zNY{w4SY5x4Q=xxrCeZCJA1`R!*b_9`0wj|1p$NmE*C6)wZ;aSz54xi$LCeY_c zOX!(Ct$(oQ*(k3=3)czTV|pI-#6mu_#X{HxdvACa!N?Absq88GD-|!=y$}-E@TBIN z8$n0v9GO(a&N&$^8y|!=&E05^7z_RHRepm8%-K{w7;MCSFE z6Gn3Hc5m3I#20$nLP4bAv!=u|0ey7b1-U^B&rV^MZftGDm~*z7pusm%-q4u5eGP^F z#;&$(FKGyOD|Pp*g6b5aQzrJ#{%{i-XW`~6gBHN_6CPBLXV-#{edMouN<1SoaULA+ z*-hpi$dl&rlcRd}q5ts=5@x^8ko?3?o6JHVoTEIzB}Y)qAh0jWZKo2F4IQ2@K_SlEo1t z?%pCTwOH{TXGKTq4uQZ<3oMi&+~(3v2TL2x4PTQZNrDTF6xG5 zaLs+qLMz`CmG!ce^+ABag=2&Eo}I5lQ%H^pW4~oT6;u{j>0h|&XOC&5k6c(LH?v!! 
zTM@?B7GIqX&|;MOOj9m)bzafx=sNmTQ)zgq(#6}L)$b((^a23e)}4G+_d-j?4fay_ zmbJF%Q-`#@ylr|xX7eL>-3#A?zb8E!poBE(LwQTB2YqH0i_88tjP=7|Io3qdYd41b z4OOCJ44Twi$r2-}PQu!!6?-a*BH1{6Y~O{6vBJfxzHGdG{h2kzhUU+nodlMAeNDdA z{^2ES=i7~xPr*NKIt1Ry?5uO()(*~~cBuOi6(z3uURPEf-!|^AjvuN|i5s{w*pXb; z9qUi1(7DeWnZzc>wJg`nF|X^V-*}%oaI}$4obNAbq%(a6bQ;(#%y~_KO84-wK+V$_ zd)+e05Kmg|*t5$d3P!j?_l18*D!HTl*n|ZcSG77c%y^X`{G;<}|JT+Tz?QDY-o9NbBm@0XF4 zk1|ZYGicbQ_a$RS)&lcK@sMMhf3J?bx}NPqu!vm^_N*%U^fQvpBJ<(dV1a33jaa?> z3VL+%0mV)04s1S@4wfAH&ePa3%6wnP@Uh)qN@9BkSOqNgtfe!3(ew1m8~b(?j*&NxGKz=cFQ_Dp&mG1Q#GJ{q{T5}=`L-yt)HU7Y89V5ztHy!}pN z{f2^$SmNVUCHa4aOMP%4p)9(G#9Oz&UK_Et;q!JKMFv><-Y&k}7180Ja$+sWlogkA zbpi7Xe=qvEWLgo4G){l;Tq&^4ypbv%k3MHrVWJQ?Gdbu*ZqFhV``cF(QS3m7X=$d? z10+Y+B0Y=8=}MLT2ta`z)=tc>Y(C&PyONkRmO`B=TS6OKbYYvDs}iUmyn{46!unmu zusc>bj%OG_C!`_a7T7`8XU$JM>3IxQH+p1D%RVG<4b^dTq@kzKTz@lNY~aEwE2Y&e zg3_@^Q0%;}tbS|0n3DP>XIcNKB-Y|ct6@JSdkXjR>cWzbbB9~poR3J(LDZm$mOZ0p z1sg-dJf|^6`Sy6pQs_Z@lvLX|LN~MRDPi(~{nO31C7Nwcu*7Y?CjD{zK5G5jON zN?cdONJ?uwz_wd0qIVYC+$wan_rm}hc$TDA9eH+NqA@@eQzND?hZ?l)Ww}`bG&zF} z^q3WfDCivLTdr9QJjS}}JH*kjxRUUNg6SG=UkB(V7o$m;DYF;JXfOgzwsqy8Vs*)gBWs#cn!i4qRS*SEf0c22x@95oI4#6ot38F< z3Q)nU1fcnM&5{PnqhjxT=`aV8v6p^4k|N0j25Gm*PM}75=1lt0MlLRwh(@r6Jp1=( zQk<#bi>K2mr9{f&aSNtqTcngx6;5Qp5NzP??sY#X;`wZDJLPi^ozbw~Q!|j#{N;;Z z)p+5T>S)J&eIMGq7EqV8Ne8Ek;1XdIr2UbZR&q%D&!V@NdE*hNAF4fpxn7i?A{?bF zt5fV>;W(Q)6VabSDxASkX-uE$ns#|zd;Q!Q4>7Dq1+KOR*TBB1Izt=&7(KVypG?Bl zDlsZkapdZ7ZN7>Q|9uriiZ5^&gfVv@Au;wEzqxDBC-*bym%3O7ELba6B+6B-GLtbn zF({~7$|yhoh_n%DtbB%^H@<=)y<`TWCQ-O>F#;P1VW%^{1Rfl_B**+w;skmDAYTRQ zMOA^#!yC!9_+zx}8G0l|p=t&pix_6ERLme)R8bOP-$H{MX)L&t8cpe7vYFl8^h0wp zya%I>p+-%#X8i3rIC~oABFDejKLl77iE=4Eg%E|C4Kl8l^3fV1rC(hjV@l1aasK-_ zgNC2&ElAWHc0!}Wruf#T^V|d%ah6d*3c6Y_zB#a$8EKf&qtJepD z3(RhyBwh5NuG#qt5t@_K$o$hr_9QESt0o_59KqWJUD9l&{T(ve^Mt`1y>finlQO%Q zMwv%EK*#3fqbbtg=r6-ph5t6tI=Bv05#;~u=6;|%E zn4YJGXx=YkWv@ZKx9LYQApTm%=&~q&e;!g!LP=|!^KDXoWo1S4gEWbdvtk`I!DOmn zgHnxsH>puJux#%R;0VeOTGV!!l4r7{y{;`3KcQbFd*;u$SBhfw^BM< 
zYduV3F@-9zaEd}#%gY%r{1zUdxls8L+0l)l9Q#CHBnq&}rKd3J0!L7@ASyRazIx+L z%Sdnr0 zAL8o5l78F86<#g({a-@6CQpp@eAIgAmii5`!wjhe|EdK;P&H^&*g02E+`vTqOeigO1h z9xWSw@ihg@WCtJrb1C6Ib}*;m--gBX0ft4#)eM%VP_xm#-W^#^#I%)@vl8NfniWnzenD>idwb=K-Z#c+h?4V7uVa~ zyQyU#BB?tR+n?3I-n+<9dbX5skyAd59a#w&XoMs5VxVMqKwEV_Rcil4nZJ7KS|9Me zAX^CX&+x3SnUM!LBjYvmbv7`j+}yzU;3P&J(_$~Ie-NhxP>$1)0q38m%mJcveah&I zPs~U~clD_jZc~#hJHQE4RPuP5SJ#L3l(-a*P*VfQfQ(3SuQ>86l`ERtBuE50%iU8U za4oWC4g5lP60m2~#ek425wUl^j#XaJ!YPW(+ESwpaTcccs{I>)9;KR89j+uZbL$L_GIp{ zVay>tCg|`l8af_rXSaSJN!4+e2!2X&1>F#$RA~^OBK*;$YdBju0aR8Nwcsfq5WF_u z?4q19l#S03(InJsNOM|P!f<8=yuiuo8rcmsSoSYYp*ZKa&+j8C6ITM(R)FWKwhYu< zH;>F$ZKH#5u+vLx$Sa>hPo;Ed+g%B<^9V!bZ|K40nt@eIda(4%73*>P;I+eg`u6y^ zQ$j~UV`#RavuswL!r|-}EsMh*PU(d?8!GT*3Hx9uDn&RU>{YT7vQWZHb}cBM5;1mC zS5|izYaT(7UMT!c*Pe}#y&hE~TVWy>xLScDeAWR649KAfOn)=DBMFFFJRX>Z2WUfqjp!PBsU>{DON7y3sggOwuD z{HcLSftu|&^c_ZJg^6Z!HToPz3Qp(urGs?dM03o9JYs~s;nJ?<1!_Bp@3!{Q75 zl=+6srRL3-!=f18ySTjfJpQeW=*8ecO%|l@MBO6hD~XsguTcbdsulF{)T*Axg~5BF zaIrEHo(F_rb^-|?PA|L^X(-~C^j&d5Zg}seoAANAeHh1n22?1bKGOI=w)57-_NgkY4T^#awmi1eHgL1*W=^Ha&V_ zSz_NpfOrq0!>NlN_=uthvIDGuZXoswW{+l)PECH+xF203KZPy`PKmI%@Dei;X0%}p zd!m*FwZTn5+Mj92lYRp)!Exp$&?^_NclsuPJJ>iuT26(L$t0gUo$T z7GYCr#~bL$vU066nfd-4M=7H&%4s+F%plycHU{WHd*x6oI<{hl$u?O?Y-JuqtExth zfl|+U40sMhP@-tW(9~|wnaqcCsA4qB=3*CyKY)Bii&BlYz?@=YbQ!08Fr_>o-*F@m zjx@0I*|rGs5?N`HCFRh=Jy~*F{^nz=-;{fdC(-1zz&GKq4*y2COYiUCKMKabjUZ%& zS`G7?)|2*`kw}@+jFk#2Fr*QgM6s>y@UYYEu*sicy{626S@7wm<5oCQ&G1o+d7x$w zP=r%xS;{w7Erk{Ql!j%_2MeTt;$vpsa$&#V95SM;tHbH&Cpw*&lTFGoB{iFk(QmCV z76KK`X>jWDJcX3q4%edds5aKF(6uj;o9ABMgF~L*)k=KaXXvBczPdNnf2Qy0z_wAC z(ln;!dEIkT%bOj$BI+HD9yL>2S$4ipd=ujGt80>Q4fHWSPG3|~vtvz7%9w)cO?t5q*B@ZZ z2a+*0AtvN#oYc9;Db4XYITYqaM#5yj3oL^qY9MFTEaExI2air%OUzP2&DIA?^JJ0Q z^%7r!+c{{GSlr}$T_V}aTb=sC$T95S_*Wb5Y+CPWU}cZ{Ma_gH#I0A)q)K1uzhKiv zQEl7yRmVtP_4Rmwo5k4XAj#^Q8l9Qc4^;W93)LPJl0OGV+|aq>TzT0LdM1J%r^k)67fr6b z0u8i#t-d77su!mS&u~zSDa~r*5q}-=QrNeUcH5N3@#~qL-_Km>yB&sttru9~wFRt) 
z>U-dt4+6{We^y(QLb%^PiQ)&zoDL|ns*@-t{T4R8^!x~R{t9X#b$og6H_TOV@Uo1| zNvf)htjYJgg2h%;l2{-CeLoB=Lc`o`&RaxNHY~Vf!u|&VT*hBGkS`mxDcQ=Iy|hSo zocNDzT6O8YvjlVvPm7*)RS%YXAcJ+3#Bu~50*4j*J@-?aRd zGV4)lt~t2MnHVe0k)kILQT^4!2Cw0Qbx9Q|lemFuPZF@)J+rc^?|r&ri4?t>bzIp& zE7DlLU2TVB*R745lI`q3(N~45X)0Byv)DJ@Au$jLz?Qgh;Z z5St!Bi4a_|9G-sTKICQ22sCHTpBctbqc4+4u3sj}`92slrMf1(ePvnd&7SIK^H^ri zlzU#FtuDAk7@7ukQDvk&_1ReS0Nwi=u_oV8H`evhSVbOX-tHoaB)85&+M9M!PEJwI zh0uKLSwTCrLt2D(A}zN}hJZT7^}{ z_+qP}4C@np^Gy?HS@S5*%&9r^aJB6IG<0n0L^xIxpmjXaGoe|+%l#=pJ_t+TUw%Q% z8FzL0!s_jz(>*G2HT-x$xK5d@Zdte08mm2!BTzD{wZ^iw(~Uhm1PY_v`B^-Z-8aGX zhYL0HizK?AXsw0w5D7Q)5`@{o`5l1&VPQfdn;=9Ml5XH`5ZV<%?QakMbssn&o^h(l zU*Ab;-hEWv`|E7P7myTKy$n7kf|U$It@*20b5Hr0&+KY0lZ5N_(?P+S;+T|*4%EC; z-_zg}U8WYvx2$$FkTc3B>5{E5XrCW3R*`Q~lYc*8CbCAlUPGxKhsU!d*N<6P^isON zIpyX|F@SR|&4i2HH;(AA1trwF8kA@{wmCWu^pGtm5PpksZ%NS7+K#aHuA=TOSfEwJ zE?2IdBTMNgkAHKm$}rmBMDLWGl z7;~&q>&@F+vyKfm%ZxO4<8EEicwUWmzY;R78;SE2rPo9WtXR8E26bIqAPW?0y5*=Z znJxCkKL6Nl&Y~0N#eJ%!oa`{tMt0S>XMA0QevMRl1+B#+_V*WLaaw$wLEZxR@)^32 z-fsPwkjW#SSL9KM1(ybl{^nqi)uMMK?4A~;6s`_6Pki%bGa8zEm;GPoQQV#TL1{9{ z^g&FL4*@hU!~GVh%`cRcv{y13@e1esjG@ zVF$)jk64fU;3WBK9#$So&UHVaI5bq{=iL?i#q26KSY0(@BvOU9Cog0h}ZecQIAul z>X#mjp+@PTG6@8oKL~*al_-)29UZzt#0eKD!|_ERW$K=fx@;1Us6}sS@9Diw7ypSc zypYR1N0|((U=JJCrY{?GDyaANx9hor*5Ji$Z%bRj;oUzv4IvJKIqZ)#cWR|rGPn0e z^qj3t)7*y}kf?pX)E&NuvH#d8Kh(ChUfZB}sjqZq%TbFvJKsm_g`jYg#cbhefh1nV z8shM#y8i?5!O7JZ;oR5~ZsIyH(a{MhMi+U!8RkKt|Np31_P#~&uYtzk%XmztjtrP^ zD8-jz!x_8(hmRZW5*Q`zvLaqAp36aT^N_$JW9OCYKFPC{V@kW=#dp*LZnZruA1#^L zVhOh>K~abEPeg43Iox!Shsr~eKltH$@L%PF`2j#N{Ib=-#M3(H_Po|Bd*w|lVgKg3 zPuVM*RUODcmBR0W&Xc{M^8|>wnY$L~>G@nK>@X4ULBd`~fP?h^TCHWrjM@QSNIPya zS2ItYr|DepX=e?Q-qv>b|Q7+3)bno3?!d-9A(6f zep*k4a8FMC!4Ld)QxbJ~5J|MfzCvE{kpG4RC6SktdICQv7h2ZH2I>Sw9z+f8bd|D$ z2FOSLB?tqqJt%Go(}v>mZmG|7sJ~nK3BhXs!8sib00F6_Pi?-f4*nX>yJh3KPiR=- zzqum8gqyygE7`?DC45`NLqPAs$)EOnS;XsS_zpBRT@HxeOTnLjn(X`i?u>r^(=qrb zNMU$K#dxH<3E>D2VlQnGVE;2r6jIjyHZ0d0*7%Yvpy?lsnX=M%86Z~&lVgyB3us@! 
zC?NzVUj=(_0!Nu9S;e9yLR%8{J5w=KA91h$=2~3-lnkEN4@}=!+umskK6$wAH#zn8 zzrJpISoU9k&terCJc3;7mblY2p~o45gNBmu3`OwH zcGy1j^CrLj-mdg$lR-_ff&h0AtU$Z>B^F%Ei{~L_ZTq<9_LR0#d=xrj-@B|M zEds?I;LiJq9t|p1T?TWBsqj=nsJDqhD>?&j6)%bXREa*qPzVb!SF?7Q8^wyhIJK_#< z>9ub0TOfA^4-Abfia}efO6tw&f%XrgVA@3AHmYg0*BWkaef19f-15fg z)zg|T^=&+0p5my@9h)#-s6b>cVz68)ng>zc5@UV_#S^zNeN{QwU%i2xP`*_N3^$o; zl4ryNh8>NwSWto@mo07OL`f9D?2Ng)Cx0(cn_HNkte#BhvyvxUo$B{?3xifC&Q%wO zrWZ?F;5kzpFPk8EIgq?{ik7Y8e{A^2Ua_IC)1a^OM_Ly{3siYVDTEAai~csA-GT_< zGB?4|*=tS>Y>0r3QT!sNCwom%pzhXDb)#|rEe1(kez{oof?xau?eyXXy~rwMSt~eN zILl(GYR3f->Mes}Wa+wZ5%5f=!-Wyy9gOH|Wm(XjnXDBc$n6;OHuHN|h8M1diVxT9 zIOE9^cMO)-AA8M9<^hMlSe&{w^n&*$kXHuM-5e-Xv22Z)FW*0#n+2Pe4FkBaY@gfd zz;1vfY;5C1x76LrMn2>D2gXeVxWCw+TUb3~P;H5}z ztD64n_Q38!htS^mSc?YMpMIzlLjM3~OqM+d&5^e;hPmJ1Ji0)H7{g;6(YQJ= zpqAE~JyrcqN!86ZB?Q(kV_w?!;#_!ee&Ubq$1laAR&G#g^-}YvK5A|Zq`Bo2OMF2T z0}EOcCv&axEDq~A&x#GosvPgCg;fDO(eC7MKpH%-Zf4ty?Bh}9 zqL1mx?a6tPW{Z*)A38|3xY7Twy@~wN!3FeMy(p95bF9(h06yMgEG?4x2CVU{Xw|ir zwxTlnT<5UfG%e^nxSRNNRd*d$Q+QV2#cItU;Te$_Ck`Ib0R-n-l!SY_?u0{yzE!WR z#+xY0<3k@^%z4uo;BlaTea{^JUP_VhUPCktyj{UY%SK-x@Z6SI>{;Rxf2c$uXz07# zZHdfrdsyz^>6dr8lVF^Ge8uZiR+QJcA0K>*OPgGEDZy|ttt=9*?8$wSmeMk)PIkTV z>R%Kph8!`_H|gCMbet=e#BT1fyW4Iea@75!yq-L9s{B4nFxqrM30B0L+QE+3oOfP| zvg?g`t?CCWLVgkN^`1_Er$i=A;?$^7^`$3~k+Wj>0cN%lco)N(_-`cF*t4?$$)izT zqjZ5m@3#c9B!{^#1wyLrmfnwt;=W>DUT)z!*Fyf@lS*d;b%bH9%nwrJL*-YO5{; zg3#mR4nSIUTh$^574+nDraBA)cXXBiONLf3x+4nq4A*yWj>cbE8l{cEGO@~Pk9cZ} zy5z!DgJEOa9MB?jQ17ZM1DmSctfk`lFRhK&63%FBxd*H3n*to)T^a^nAw6sAjF-Bs z_hCKrMizMA!8tuW=0CPpG)LH0#TpE%TB}ltr0olE4d-I&ZjQa}vSB6?&nI&oCF${p z%*_{Zv@;KyPW+47mXIS;@3t!~xV}@FYZ?6Fv!S>u79}V@2$s?CHTVM3Xc+95Mm!q0J(SkzuysU(UlU>A%1_cTD;-_-6u?{M-T|tlday~rq&3^== zsIdN`06dsdMX%MrA~y8CIvdt>(H!-yzeOtodcEcIEutcUrDR($G}DSTLM4=xj$!_z zjn6Ra99CwOQq5Pww$GPxkbm#uPO4sUGN;r0Hsoa3ZPkqQR?wTwTta=aBIl|I{ljHy zsRz9ZJpJJZ@Xn^a^miQsFQL)7*7_!%Qtf@r_9?RvP|}o*q+7aM>Rr_uV0EnbGtr-B za=iS{QAX0eAN*kc8St2mJF~X~cN)7kPcV9Jj|-r$$F3hg)ujeMSM>PKvvo~y+Xm9a 
zjCM}37Tua(EM7^BqyS9dm^r2gJg4y_4W|$k!l+5{j{dQx|L2#3W;v=gjOd9xFRZ(D z$xZVW(mJ!&k5l#K{KeYE5<9PNpyL46$j@fof#r5Jw&Z+RUCgREOwz1%aT98lCU(cp z=t&x}B_{q1>VK9-IK$M!PtcLUpol>=%e{*N_qNM*zRa zone;v-lp8~WhOf_>GBTE1>?ZQURUs{jzsDG+V6knt1{Z*p)lBCeeOI3KZ&wZmPL&@ zlL8*{1MMfZ3wum0q zyfD1~$3nkT?Cx~W#Hi0)varWuD%mqtFv%J2tf5-^|D)^8z7v-LY#W8v zcX=O7a(1aQrheVs4S4B&H&FA~Ba4)SDfe(7n@q188f|!<1`KY9K2n)KbNyo9?%$8_ z%NxmryD)<`%hVahqykb+`(9(hnoMs|zOki0rA$d9=&9Xwlb}2o&QMNE?K2KWl>#E^ z*Vyuhb)f5YUjDo0994cgpep`sv-T@{*|JVygEE^Sl0ZCr^Dv5&Y zvrUhx#5MDoVcMW+Mn8{_|5em>QhVKK6m3*@fD7$2V;m2T^ZyQ26mct_$B}$67)a6u z?!11%HEU7y%MZpnip)->CrEM5hfBSA4aLH_j-bp@lib50Qz;{hW;@@c`*F|i@Sh*^&JS$Z|pfXjf>-p>? zxsfwZuV+8b4!RPQYe0*KP(fwQ9mczg=fzS)f3YJBvI!pu+6|HT*ki-Rnc_gGth2{_ zh?Xr&^kjB zYPw{1@euCM5dP-v0;Zb{q~qo8A!hVs+FtquR-GIGSKY0TL3QU;m7zouo1LmS5r2>w zn*R&QcQu)SNLxQ6p{+9N$)L_J}KafNeWj?n39Wd>)~5Jt!Ioo?{6pJ?+* zW^GW&jrg|mxI1u|PwB=E^E)spslIZCz;SeJD(+OX zvJ^u|AK5mG9nTweANBfYIde%URrv^R@?DBmunq4{yULbILc3}-!687_*ay$c<2R2&<=|8)M`_2iw`XFwYJQ|nW=;yMtH1JIvwEFay41+1R=c zol|=ZgtLkViU;KM5SddeS3nn~Lv5I#y@)=wq(v8}8;gjIVTRFVx;X4SvV=nDIy?UB zvT#)2cN;O>YeSuSBp1BMTe#mJL(6`c6w6Lu7syqs2TWYe`RH=}dHK#A@y(M4_yoZT zFNmKLDMz)=Qyr7za!Wot0iOF=;K~a#1V1r@-`{;IXZDr_dnG&Yuy&2r?1n_|54BEQMag9pK^5Sgd+eLLM36J1Vk?VC1DspAXZh z)I`u4i-v@z<30;izcv%!ta1vK1$5BT-Vs(G?f z=>TNnQ~1#Nd?AE?dmF^ox#I*<(lMyQ6C0UR?X@ctkB2vwd#Gw-=gbW}e_v=Szmzbk zIp!yCNE82ffgsW4&5AyU^eH{Ro7&npOz~IzBo*IROOx``W%)0!;Q9FTOaUP>Uwtw2 zfm=y!q3lIv{Mqz$yDx~xRyBbo90Tw35LV_<%2b*PYFsLUam&3ctpiXPzq!)>?8W+7 ziheA6bSkFjBzy8YoTCj=SUH6?Hc1+6tTraK_W-HFxA zxv;xNr=IcPvGJ!I0bt}BxAF8k2maZI<1lX7iTimInkTL_%4oEZ!sY}!hr{W@`?t8^ zYK>^J`^{N3hm_GTkn4TW-EsC6LsOXLrE6z%!s68vJZ&?N!z&g_qYaNQ>n`pQVQ20u zZZ*nx8<(Xgrrh%v>)h5vA#b%Z?X4)((jXil3jS7Cffxj8zKh9twVs5XJ}ZcL=jcgo zkK+@n_l@!^IXI0b^Q9VUZ+>`lIeyc}D&I4`d&YcrSY^HZT?rP{@GFd+LCxs{66SlvSlPC2{~y3%M?~6RCiobZs51w6Bm4GA;qq4;Np4Bt6}0!TLB7OY~~)1^tQ?z41`h)Pa?Zh4_!D{ zx0XG6{-e^eWfb)J^h#(9!Zd;q<=Y0{m0v2HY0K3^(NXE4qP)aV4KJ{y&V^=YerX-x 
ziA%nmnPG?}R7`4*9xyJaYJQ_YcN)CXO-WOq=rKPUzHBRtb=>Y`^RL)%09k8B8^RyO zun#;2y=N|v7hYU~H&qg>%$}`fLQ?i<-;V|ztwGR$CaNn<-`|GWO#C-~AmC1^Emx+| zm5IH4(?gi)URM8f$g;aL84rI!Xzpfuf>_VqIK*i$qfV9A`Sk9|*dTSaHwV3XH*)WR zUMu@S;~Jks<4YU*kJh>L?#X=X@jlJp%lX*cUIbZB6+SCq%lS?Ro!ddL-1VNHoen}5J*{fQ%?sFp)8%5s&4P1;7t1Ge zb0y~yq=f2k)~~g?tk(b3%C6(s>2voX);}r4W-*Aw(UI(=07slZnQL6_QybYxVhVtE(mFXCdmm-pN*4 zqz5nR7p;#Jbo61RMkJ>pVXiCg&JI&&muAFr-?R{-T5~N^85Ya3pjQ^4FU0x_C8Dx9 zqJeiUlY85rWyy3fXDK!cvi-b%@e35nvAC77j#q<@B{sfRS0TH$2u6d+j`?=4e-;tK zM3XJRt2_0mVx5`WwLgwPPeL^bAx%!-)Rj=3ObdD*^PV(ow3uMg>>Xb&{VtqWT$4~e z6~^*>3Q{}IbMy38W4|CWGlt-pUVRQD z=|!1WS>!YN59Gvcg*&ViCBCU)%yg`0pc<45Lt=*Jo6<YH>_E^|&VLO5P!|7Z*Q%~CsH|`}^)-I=fB%)kvX%i6#v&HuTA5 z?;B`qGnnsroKq7F@a5TF_QV~N=AsMqYDaOIFVpe8oq(l<@VI_cx&*Q0WmSV9F9iIW zxSc(o4_qzhj)XS})eRYK;Dh_#ro@z%UlCWQ?k6oWXBGB#1HF41jp`G7e=H8>*Q-Js-=#A^`+REtk^18sc8Utv9 zA?_eBie0j2A8iti6M`}dO;U4dD7>2rilqM}yS!W8^}?C$Y@qmC{t#-zt_7WH9CEss z!$F412j1Zzhi8od7QeX&2S7BK>( zwP)37_WNN)vfY2SzfI(c3coVmxM{{@_WD0|>9E<2LdK|WkhRA@Sr<1}oYbw`a`aCm ziVnJd)A+3j4)3C@Z2H)IUS{$wXO0BDJpSSSN}ZM2NJJa?`((S<4yvHiEUm6=^SH`q zjbmd=yqkMJL3DowIOajO0Szb`#RHxnM!(BjZucMT(RHUG#*$HC7I<2(=VkMTh1tQH z#iJbT?dXO})^GZV(nDNLMgsR^b5n0I*6g(c%(R9Q1?i!CDa5yY(Au{6XhKm z@$7F?iy%ThVK-nh7iDqyvMpuwy7CiW0_F>`kEOt?^UFC7#P&pM3w$Eg7RaL5B#X_d zb|()*j_w)z*!;7%35XBj*OOd2~&`~Ye7&73t_zSu-damML9qkYpmynTD z4_i8``7Mf8*X4|9Ks$MkdfS=DbFx7>Kp;CJP#LDWmM;N-P!y?a(1#yp<}#uCH3 zu=&~nCmOJ%bU&B%|@pr=9K|keqewTCG-n=tg!nXb)wA6cV1PKOui-PX& z0Uusf9nmS*dXGoxH8QoByaIjE5`=MS+B3%>>M}3U@jpd#9LzV96XtexT`eQn49)@>)>l2vu>i z1RR#+-=(*7?#Iw!l8^<=YxD&P$^-DRZZ#8)|`wooy)^BdZ4MkQ73OU&Pk5MIvpqpTK?_RD%q6CF5S*$yN{ecn><}COhh7BrnXFRn zWsmd_2M8MNx2khtC)lP!WsH*S3cOQe(#wnc$at({Kp+}`>o9G!t#0vRQwaucUyE`E&qhLTsQ~h ztkipSZ1?W3+>lYANRk471~f)kI-1kT>0ljC0){ju4>!L(K9XJj=bh%4`UB53Im;l^ zncY(#?%gvF8yh_Vlr0A^bNlalXC`JWmx5ovrS!yZ;0}9tJVvLb+6Q&ODk)ZG2r#Tg zw_afCm&DCX`a_xLCCrsgNEsWPp=V6Ks`yA$?RkdcxCG~5emEk@C8{nd*o)CDmv~Cg zdSVwk=Y1=q-_h)T3nj(Gu#sGo(fDe(A3np|Imc|?UseG$XJS(S$xRhE3|3}W{wU}P 
zK{y@WHNBKMBrM6+lU9Mjk0fNkCF_4$Htw{lX_$6Uded`g7lPdLKf~t;scQ!>H}N{h z%FN^;BPO=_4|o?LBSY-Bn)a777A=dB94G^(5`0buMkLfR_7(tx_nUd|o|(1EtEucX zXT3VDbaH72%5;ppP~>mx0T~n$f7WGqJ*j%={kzc59K{PRXWa`a%>pba-P+hN)57Uw z9XB~P-;dw7C;s58U*~6amHC*&czrRmubj_nuy=lO%?wmwJv@t*iXIO&%R#~6!iRW&st21?hP1C1bPCMZgmB+O>Hobo)P+B z!R`I$&ly_ck(G8hrd`-H;D0K<_lXVP?$Ozo!6@DBPH0=5~_tF6bkEYmy?U8BMkU2 zcU`h;TxbE;o9XiVZBf0wdve-e57W{^#1}2t!Oc6Yyxk+Dd40p75k-8@@3bvRh3}d2 z0B*o+Z$T`ZNTvr(_axA!{ib6d7Q2LPG#MIDg&TOQ&1cPA(R|rWFfEqB@Rnj1C;oP$ zXLOvviF*#FNEttw)V(<(ekyl{7Qo;9OJcEoYR*&0_g92MYkX{j2-ZQei<}Fe#(!HV z_!#%h^gwKP_t7!`(J6v^S9ZE+?g=^QMwT{Dl~j3>nE|FZhU&;vkQ4URgmQ`C z2CdJ=Phvj?x~nqMITri%7L_S0C}~;JfKWgTBLYLIF6uD*jYwCAXb+Vhz1`mDwkm);qv&msC`7cBt4_~ojvTi8-jXR9P?(=&9=-3DETVYQ6eM)k3QEIrwen<5ef z`IAv^Q2H#RPPsH#FEamWnh{!BGuI{b5@nM17IsdrL(O+K1Gjis8&;A=Yi1723An?D`h zHJg#Qq|yKD_ww{C^tU^FIWwq=zVgr#W(+j^_{jMhpbm;gUeMR!3o9z-Ir)jj#vU<` z-@GcYAc^^9TS9soV+yHVb760n6_8?lcw+`26XNj{6#ZZ~p+52k9@l$h69r=?gVAlv>EXDhfzpHPzY-?O?cMA(vKRS@tnx4)7 zC^!>l!d#^M+Rgbz!u)kP*C##-{JXMdz6tACSiU9f>rvN~`@>$AunepB1EWq&=Kl#H z&{R*%_X{+v&iETuV3TF(6Ux+F={LsGcnspS06WO2k#&34LfT*P#qjvEosykCu*beB zC?5ELJsKT;0f0-^ie#8oOZe*;QC}10^N_N`c{Y;m3=yc=^(s7=5WHalDL6pxX~)#v znjCfy7yC0xSz#rRs`xvrmLW;+0!-j5M2PDDN9g>cgc?*8pP7EZLvwhqs9wRYfDqGj zUS7$R_wrd8%pg7*Q%PtDfB0yvINBRYF;7eOgn<>YzZ-Y?VhwqL*wko$s;Hvkf?Lm< z%-rEILAZw>`B?L(+3^Q(xPEHp2o~h9K(wtS{?ql+?=?GsW{IZ&U!k@lzBA*KlOoTj zwnuWGz5hP;%Ua)fb)1F+9B`^zYz*xqR22Kjc=4Sw4-?rzD>@UkYv-GAB&=!kdtF}3Yen3n_!TmR42qk}w!v*LKoab~wkt{s zS)Vodd4*yn`5}n?zFi9AIZ(l8zdeLs|0$>1wAZ@n7~eVQU0{RnujX^cd*Q2JgKBw# zmzDR_AAW6wf7ZeawGp@|?u<|SeG#Y^el53 zZ*d&0k>ej6mH={lAI<9S$K3pXKe6dG@SqzU*=yC}yXi?57yP__1Yn{X#J1+_L_^*s zK;=SNTCx1(j<5bqATsblHf#`7`+Ic&oIuUqAq&t@!*rH5Hnv(*&97|<%U$Xr0Sj~Y zu9aEQS_kpBcqcX#(jzzeT6y!;M@@)RAJ`c2mh?%;hv4V1BN6?v8To%ZQtB-(choM4 zv+a|>s5THf6_MT{rpAE>0a}3LkDxe!mh-fh^Lux^C2QSKF@0NnYXiDDAMO+m| z{~Sz94Pgdoj%j{&C7?#W0)CG3A2JpNLY5zv*72$+fWqb^P{DWd^99d?&4bp!FP0)h zSZ%;B2#p{_40u`QWW8d5$Nm3!oCxGAJqh)^<1Ou?(Bb*Bd72vVeBJ%_MU!y?d~&1r 
zqvvdY^B(&Dp#p40ZXOFjvs;S5bMukZ=C;18Ad}D*l^i(1JRNHa99Dz402%W9_@6ze z1%SZWsi)vJ`^$ce#9G1gG8UWWO=yP0&7&7fKw1M(vj~GOzS}q7RG{q%1Mlf*fEE;2 zR8)fO%-KH@0TF0x#cV>ggcej~Aj%D&>l;0cg$dOkkYBHBKj}hEaq*clp*HGh?U7Bf zW0frcGHBw{*V2`=$(7Uil~a>^!1KGq{U~1l=}V(-R9B2qf_D6WF+C#h6-5912tvEr zrnYc#lBM|@a927utTTJ;`zmWm(0vHNL`LGJ=mZ9iTAdn*Nc$4-AfyTV5AP_Sr}5}b z)X}v4rQYG?Jv6lS-^(4>pV_qkHqaXX-A2zgjaNd4?)*LBs{+>hDm(Ma3)2oLM$YYP zKpOCV(6i4S|Gv%tbs}u-#cs{X!<@FIGp5vfsrTYzRZD&&CN;#~`_DdoWldcX7kL_q<0+0oBXc@d6nJKl;()5)|6WLj);SRAGK+7`qn2j!+jk;~t6Jf45? zm(yZWC*QvTK4@4}Fa1zI9`JQbSaMJJ4+_8U*^=cQk`U7`Vj_@}N_*P;8BQNdU9R`}obE zi`P8ipKr+(;@*t1_;RWaq_4i8Ox$~ZcJ;~y@Wh|km;xfuP>-JH#1E>bHzwMBF;x$Q zZ);(_SL65{q)|nC> zVELh8j-hC)h*EDipOvZWa>Jdif3-Ru-ESWLWCT+}HOpJwzy^M;TqPy{uHB|{v#)Oo z1rE_Z^$TcTiNh9IVG6Lf|BXVgnd~r!6l-g>0uBJZa@Sy)-1HW2n9mOH^db|k5OMcz z8GeVE0*IFfodzaV1lxb}XL(qNFuZy}$6`8l%foupCNAQ?N3qtVVm4z^tP{H0#K9(9G|>Ov1!M<`!*gJpmY) z(t8uB_#nqV;?Lh-L)`T{VVd3mh@yDWP5ce~ro(%@arKw_hZ{fwos=+DEv*JG3IoF+ zxc}l5o7OWst)UJsyKMX3BNUj_y-c568T$*#CUZ^4FcGNy?bp**e>ec&Vbym05A#mu z0W{%|D|rE^4Zah`85uY&%(``emB70DZxpL|V#tqwa7kZWIW*YU-T zHmRYp&Y0((Pyg{B6SfP^Ka@z3>L7{d+4lUAnv|5o0cb~qI= zRc)dk?GPB2-1NEWj3d@*E?V?_^}k@(RE?wgC2qyHhma>;jVH%?S-~>ac?96&Po8W( z_g|N=W%++6ObFuvn;-kS@~5it-B-Y~i_;p^KckPB7*d*&a%}3u%UcjJlH^ky%&odV z5T-WD=VH^uIH0=L0A>>AzaQhcnGxA`ua)JlUCk!5FaDCOnzp|;xZLTb8UQ}eB*=r7 zdN+>+=JQ;U$3VXGp915hGtGzCcJJUYXbFcWWZY8CmO-=h*I8k?RnSCOd%m0_exiEHJd~ z=g&jDqQ=#Ic}-u$$5&o^hquy+81q++-rpOJR)e>?16B#`7UGuljveeAQt+WV<`uKI7GS82=Ign zB8F)B!|${pi^2@*k}O9AswU)?>5X>@7=xN4m>gDJF`}GeE>%1`_SjcAkto6o&>hk4 zjHy2TN|Y!Wb9J3|lQ1n>1x^7ok$dEpZ33uX^UmSjnBS`lwNbi(N8!rfs5RuQu3pjJrNZ$+9!7#+W7KGiNzM9Cm5{RD^~L9SJHJF3Te2_UA2N_lV}qYojNd; z?;16YUkPGX`ImvJgG~BLp5a06Rxt9;m+eN&!-~>;p1Jwr+`!1a95(UiMQD{S ztRB_G4f<*J3alc+_}F(22~p6Sho7AB73e665!6%}A#Z+=Q$11a$Y*0!Q{@1^>d@)l zb#zQW?i3+d?OP#)*m(aDC40M!Evb{*&d&N8tTR7bmD{A|rS+ zM54|aUf{+EyFIp@ zv4JS|bgpn5x}45~Qq}kI+Q%!$4przTm{fN#4IXj0Q(PxuO=eFLoN#%ulJ{Wotq{da zy_gewBAGNLqfdwV##bzM$~kRm{A-*C&p3vW-Vv!_LRaKq!h_=pd1v}Z^({J_+-R7? 
zX44ENbnms6Zu4Tv^eE*t==?Lj&6x(~jr+6L^PN#7U-^F9e>u_GJvk&&%Cokkm>{!4MB=qfulI7gY{~5CAi-Iq zue|9uT?bFfXZMgYD3-K>oNqs6ZUdt%*`ymkxh&`Ob#*bl=CBjs^2R+1vr2`P2c4R- z9o0TsqBp)fC6P|lSicmZllak0$w8u8A*cn|)>fa_awu_Ck~y)Xpnh}7i`(cMaR#Gr z+OE$Wd(JR3=LeCj4pH8=$$X<;ER*5r_m~sLgPp>2!K^SiW2=57I2`K0b?%l^D^Kcq zX6yn;=7CN_rz`8c4b}^m_so`!VNP7-Z0bRAlg4mF%@lVbFEPl##5GC;$vrDboBhtZ z-!6(P8$H5o0yA-)3oVu`tpJl6)0)C)9(IrT2Er%m>OqC>P&u(sMMNL$)UiA$gxdgB z_u9;b`W98kT!rISuVzkCqT0+aEf=#BO9D&j%R**-Yq;P#_%;-@5Ojms0C*k0yKAwN zvSLoOwFi*-WX%(!fuH&Cv}-IgQM;U@&1QWNT+7U#n-LQ*vI7WDvVz#BnT!h14D_(>x{$-+yF_JTgT-Q+<0tyn z{$<0+r$q~bCRh(#)2MQ+aAzzGi@~5xs#3pvk0Y$W?Z&>u_SaGxRkF_FEpADi{?ef# z(VgWbi6v0Pl3DpWF3tx#+go3^PGUXZnF^s4SKwaH0t!_W%m_tKoISGB56(pAm!vY1 z=Vro|1U_99YDIGT0bAJFbV)_Ta+aHoISji}EADj;=%7p%a>d(?+Jc}{hdY&(!>oNK z-sTRf#n@y=HoAEy^NT#i=K6XERJTc$V8uUq;PIoo(y%d;IfSJ*ld;-%XJePDo}b=z z-{*Md=x$=F)Kc^pu$F|oz)q1iU|G+C7x9L9eSe*b(6QT*a)!ylA&uJRVd(a;YDbUQzfTG{Hu;x@5!5I!8TG{8Amhi>db zp{_4Wh5b3p7lSDmzS@}R{H5ZS|@tl1$6Ge^3 z^{4uIll;yw`{1J9NL4cCpF+`R5SRY+_-DDr;kyVh)a_$djNvFA3>UjWNFd?PvapiR)55yk7fegw zVI~xJf4cj9z50@z2 z5H@T6wkqAxTldl;e1AQmC2LFn_P(Bc)duVT{A2MoU;d_GU}tqR4{7CUe~LTzEI zKXd4UZ*1{=ji6EUqTp)H<1fe;{kemjfzr+)WYulu0SDjhNJm@9ZfA$|V?E^(Ws%nEq&x%5ooyf?ICTCmmz!PtO>)5BjD6qP~2w2d@Vv=w}FIx-XNW8 z&Y%QSskAEJ=KYPu6OEEwsZ?~gzA4;!wQ)U5gSM!p1TS9M>D&|LS9iN8ST@Gv2FevH zB>BH{?tCG$?AMakzjHxW=Vw^28cV{0Nzr%?X?L~H-)NfG*x15?+nlBwMpZ{-^X{E# z6Wz^d^=La}{z=XZ zufN&WiGGtsZ91G!KiUf(AfhgBXa)!7dVgiZwALrt)RuKCk#M757ZM|s1%RYXN;ggU`t9mq;vBC&Z&)) zK8!xMm_f*VeXc)kKQ6z|g6&LHbH&h6okY=}tVidn2J34hB-dz$ zuR`c?lDLsjI_bu92 zy?c;3GLg{+1^;DN zKWJHAiW^>;EsLx@5*{;F`SoZ?j$q6^F};}Jnlh?ZViyqHLSM&`NEBD z+WA5tBXpL%Ij=?dD)UL-z8Dw#NK*6-Wjq$N5dp3=7~%(ddx!X)ToGg~)Gu)DR|Uuw zHLLjgq)D z!l7L$BRr0B7ksh(b{5pAo~fv{gpBie^K#9dR_ z5x-~9enLk8H!p2)2E^a8%YFLw5-fPUJl4WcESOOLd<6` z1=ExBZ?%3YrD}$NilAucUD*l9@F}$!Jz*qX{M(o_6q6a4epkILK|QYYVqZ|dkImP+ z#p<8k&(lX3Q%mY}aoc0{z9i9!#oauA$IiK5I2pDcu}A3>uaMyjyIuxF(R<7yjYoKi zD@ic!$2BI5cda~#C%vzkKBd+%JW}zV14kZVnmV>iJLaa}m)@V7ZWCxM+4ZCJErFWX 
z=fwlMkM5DyTLsroM#9I?M!?$AjpV>cc%>4iAIEFI%t<5SLhp$aH_pUQam7}XIjv)6&G-qPhD{w zdIuZV9R;`uoj={JTVLepQBm`F8;Z7UarqJUfzso`(KK*t!?|-#@jd|Odl(;q6D{zu zIYKQsR}4v^DLR=GHw-0(S=gr`1k`}pUacdoBhl`pjN|cO_AjmW)ayO0~mD; zNFC2^BDP+`xrginCc}F@~cZR0eN9+jH6K!i1#=NAcoVf&}rxp012lX#EAQmXfmPib_ zrv_22`+%S?Pdl)UrvX_tu9jMc#{qGH|AlkKajm@Q$Ovi(%IT>a_mL3chuQ1!%OiEcCYEyL%Z*1ds z4zvE|;PEBEhDEI1Ujdcsv3~cHt5zBxJ*Rj)4spznOE>Q^7+lo>e!)FU#V*gjz|@ zwX&L}?Zn`KlTt{g)N&O2)l)<{e?%uMAni^R(NpO7B&4w!=pTCACi5ZZlv2^hej%4d?uwWcDEr`oB)X1C}&+Y zNPo7};y|f0*stImCp{d~%-A|Hbqdt`^Q3nz!LRg*;2lw2?yf$te7$3*vd41j*fp}g z3g8mR@>hir{_y|h1F)zeI*zVH*;<)7Qk6|S{!e)wn>orC9nq@0YvK$fz!jl(2>EC|8?_OoZn7xNY`2voRC6)Fj6}9QY0E?^WvNgzE9wMpY~MdWjic z;A8dtc?OlMX~bBYk-pd7;=H}+S%l09LqF|=x^FU%!Aw6z)^}w5(3QV8|JBkm^jec} z3nE}x^pU(n-|lf81Wa&Dj4;c~s7*jJtwbphj*sg-kp{`Wf-H&l1fcNevv6>l8S0^C z0}foS(2^hw3UrgKYuoR48;=o-qNu#=86m894+hnp3(`4C{@EPA=~a)~9^X74DZ>zo zn9+GYHDVD5i~>Mgu5G;{F~RO6xrl4B;KQ`+N1Z?S?iuFbK9lw=RW=VF2&;hO3P`cL z0FJ(s5eD8$m9&J$t}vB%h9HH*T>a!j0P;y5o7&ix-v>CNJZOPh6%=6|)mga!T+zC8 zsd-q|wZsA|a%c|GrnaZ-#h}`^Iku~p!4630g~+I}cM77%H{Zvx-C${AY%}ozC3hk! 
zF0!Qr9l^ut-#7iz39n3ly?a>u%#{N&%Q31~J0_T*urd#eKjPT^>IiA(MTTSq%NnA% zSocib^mhT*Gp?3Y;2Ms*xrEF_2bElBO8c=7fZ%tsU1?Prq7;)dDnXC@-vVVx1K8=T z7xCFf%;^Kh#fLCGyrM4iwbagz=?>)i&OE1Ip71I0Okt1%qte<>c?lftF$%-fjv(ard}Zd# z3mzB1>x-OQ?Tpj(>6VB3s*dJxi`3NvfU@Pvw3DV=qwR*xFpD4M#-L1naNj8im{NB) zJ12~i*K3U5W!1Jgdy5OG_$W*NQyy?|h3i1D11uhq%>-h{AI{1V$p3>*ZlhwB{4B}1 zeUk2(hZz){zz5tbW#mNO%CLLgZ&WK4D3ADEf;Z&<0|qx0twM-Yf(-t_cPd$%rbz*MLXi0w=!p*Lh>s&2J zU?9Au3^6bBX`M&iiX&JNdGaaG_@4^=qKVX{*rvQqcaGWJxuvh_mYsdADOuDJ+J=q( zHfBC(2;fiQp`a5t3o^hRsC$x;yztZ{B!G!-1(k%9#j_n-+#*R*VIPcGLq)<*rQ4kp z3`tJ}Me%oAwI%xp& zutC9`1Is|-zZe5c3mY!(@x!(NGX75TqcGyNG(X%caHfk$`jW-r@Ql;&)6s)L9JKgl zm@RspnGXe5rgF5IeB%prTtV9EejYBvOSKu?@wO9Y93A$>)}c3H(!*N8wBX6XePrL6 z-SBOSZQgcKFt9_58RMymMwFxNu|RqgGXHzlj>~vH_A=<(@8G->+?chn3m&^15WMs}6JQ<#jSzr$vja z6UiZs;f6C3hI6<2SCbA|T)rBgyYw(V4`}SAR+Bb@C><=*Q7@E`MF)gXRfJ8Z5`j(Ud8&z!L%5MCw&Z}zg^5Jb#0?28FN%9ntobjdQGP%4YGau&aAi%M ztQ!FO@FBh`4^xop*L9z~SrE=ls7`E_0;2clp@G7hP|j)I2h3bSjGz7n_+=?LsmY7} zc5EGDyv~xP83T)2HUz7Z>O1n48%!cYKl=diQ8#S8PAXWyErf~ht;2Q6F)c`ehm9Gw zZAAtEnIHui6LcbPpbodV%bgA=uKyBHTr_$S4Ln0f595Lzq3UwY^0M;G<36ZtB7SkA znKbTCdeXaR1Ajzc($#^UIt{*3hg;w|0eSBLr_siFKyInt7q(nb)jj(wOwa!{p?aj^ z3DGE_8(ZGh0L+Se66lgnoBwgdB=$3y6xj~nTC);u3c^ItL7=b2;9D%xZK504m}XoZmq(cSO3;*HUK;p+;vRiwssvW zLbjIvGLgF-S)Yt{GS@V|4QmV^%L1(w9cxZUAaMJ8=#pg%C(H==j#)|1myVeW+>Mzp z@54h2HsR91S?CO~I=Xg^q}02CPPmpf3hL}-Uebi>6D85V;K2R-~;;0DTMF>_#k}}OsugtpTE9gS0>wKE~t7-+~S>;eh+Yg&pBloUs0!MDa>9m z!kjLgiA+V^v zwMqRj4u}s{SBu1tqiPgGp&OMg!-qBERz7fQYHo9#iPizHsBc~B;)zp{cIFxpaKBA? 
zCR3ym+p)CI5=0|>-Nnq8AJ8&kSqPJ-*9AO1a81d2QU?WpjZD`alcqhGV6KWCEpGZ} zr5EpZaXg#n^?Oc9gl+SGy-G&I3@>Nod*xUWtxUaRq8+e-vBUAq^a?nvja<@QJ;js^ zu#fPq753*R0eNd!sXB3rG}S{syN8K=?7)7|51U#V49w`=6Bd|kK{$o<%VYOjTE8c| zVJ%Q(|CjhXW;XQa(P!FMr^I*G)vxQxs`QViOKfNTxd+vTO24&OABpKux4yroTYTx{ zGhMs0n|Ionp$U!p<9=^ee+EILus9y~5Eyb}rcx@dt*BL83e8$^jG7(42qdV7)iUEw zgc!M4;~!u^3*squ9^4d+z5aKspFv;Ow(xIqZA5+iO+eu};DHHpzBxedJf7ge*j`Kg zpm&c2{;c*Vz614qc=u+ebe)k1keX;FGyVuq0QIgID)@>Wv)1l~GOPsn^BR@Ll-(mNzymCkehUS+Y;5RVb6 zy57-tU&q{eDMEcwZtQuv+*^@-f_}w1Xo-_Ka5sBV7oDyTAm+x#aL_p+Jt_K|bHrLD zQ-;X?WmObqPt`e{V(T!!!7QiDto@~nnXeJqbsOp@2-k?2p)pxM^b5PvlS?~yyFnIz zEk0kAV|P;iym<`6?IGKhm`Udy5IMlfs_Y}Q=5@dR&&Xq7-cvzi=Y9ZJJ*?ljFw@5Q zPj=nD@FBBD%G0^!C9&jPb2pn4suQ^0ApPfnS!af4zgRjG64PYuXW(SMm|L^tT|0fP zj#8MreO>RpA-8X$s*jkiZ7Wr;*MSq8L%5Ye|CAZ<2*A1f;p4o&_(%(nP*1WGfwH(r z%hF-;uLy2h?=n!^79LfxKkbKI+1bqN@v%~{T#w|A)LLSMLg4N9ejt8XwrENOWW5Q= zuDcBkmGp`=NZoP*8b0L0v~;h(znx(ywgNldB!_}$(`D~? zpEdP6?{(LqYY{iw;;m)LfXpOsT$q?8u6Wd(ho;UY2cje(A`q^3T3Q~4uJVAuo(~^z z8AQc@Hx=oWyUYbvWIiAC_6?!cPOAL;^vg2G%Bf{YWQ`!;Whp1**XH4nhOy-`*?z6F z+D42sh#!c<80b@)a^uPs#VExWki4@)(h`>W<^x*ohI+x(Yb1z^e%v1#!q8!YYiPQ~ zgbWB%W~gcto%YRw1+V`eiO&i3n|<&=itzwZ4uRiq`m`yFaNLaDP=_Un6tHb8@ zU#Y0m@0=>P4otGAI@4-4^n)}>m=rixliEvayKhv&Ga%zw4s?@A*$^{>2R)zm#~Pi; zUO|&zWygsWU9tS11`XriO%YL!js7d6obON~rTU)NBEgwOW~{r3F%IlAWY_Mqorupj zw)>NI2{|3Al1du4wVNQzgY4ETU!gKC^JFj=0VgJ7jtm}^Cl0|SZu@#@j(HDt%ovIC zOxKjhOj0l}gjTaJ{Dw0jjj=Pl3PZOr^3v0#W{Vo$Vn_kujO07oIF!W2O3;uk{}leJbClCzdg2l%B#E+Uyw;|>+Qmw!_jrapy~DpF(OHyz zt=?BZcvoCv!)REcbJ5Y1q5f!V|2a!joQJIJVIOjUI%qaU2e?pQo>lJw&5G{i z5oGrgA+!uk2eOMA1^n$4V#Z>e)B@?uh&T<^oe4wsYPk4Se5k9y4wc_59^n;*~@(nnf95P8y zfyNZv%!O*?g4IFArKo(kX0#$dTXRlejPrZFx+n@i+!J{-3&%qWBE>6P|jF=@5wJEA(x>LFR^L*B;`WVd4$xrL2s+6$&7|mm6 z9701AEHx;Wxkj?WVVF*^hLF@mxZBt2WKj)h1CtzZ{TfR51M=814X1je2sO?;Rm@Gu zHkE?*3j*R)zG6lE8(Vh6;r}kto{go6pL)@6SO?6xifxS3FBDab)*lwB%Wp zl~(duWX$P@Q9vRqd5l4j1SwG$fDA!-zfNzrm@J#We1r_eEab##NHpSg!&xbpbxYkI 
z{y^wGHIjyu?Jpo)p%qrtPuhSwQ5CM;Pfz3Q%>5=;E^7*BXe{dMA$~8rGk?37UAH;B z%3t{+!(lSlIKbCh04oxAbBeCLApc0+0F+hqR_SMMFpW*@!}x7G5rR#B@~QJbp0MQiW2)E>1* zjM`fs)U3TXtzDbgBenN#ZDNx~M2!RydGma~zxOzP@B5d7;E>yWU+1{a&vmJPZ7*;O z=?RGWbK*Vd00{EnL z?nR;>Q @TULbxZo27W7;syCG3~QdX7}|nG(Fnq>~XtIy>1uFZCC+v8;9^SEVm3f z1XK_lK%%kY=&#=;%0rTF{BID@$q1V0QVHy4`9&8|u-(>SwVRisgkQzLbf*|o({=Yn z>xPale3}Fkc1jbIl`hu<{EA8I}nO zSXL-<@1hHNtrR$Ou(og}f^@}NyUrB#QHY`QHwy6tCJcO*&MFX%j^Ut60UV+E$?^+*? z|4TI}o~2t?n|9(UqZBML5d^9A`Q$YPGn^Li*Tox|uV?w{v|vH6%$I?$bYXxQa)MeQ zo=9b97I{MdE13Hg>v@Tr;LmhNF*blA{8#a4JyMVVmr&Jdb|HE#ZyH%{V=w-&_$%JT zLS=b+4%>t8Nm;ahJB#|Wcblh;u3(2^K1rYTpeqJUAwKwh9ew-J0?-p4bZLfzT(Utg zx(|$hu}tw07qn!@)qo!|<9V)%r%ko8D81TxDU1JAAPcNqhtPwBEMV{Lnh^f+^d8Zf zv4B-hvAF#or*3#IFV4@63~<-Icz@f!AWRyTScpfF{B>-g`JA+)yab82+TGNVo}20H zncmshZ@hr@X@Q{B`EWcXZCKPy1@dN-MYG<2ioAm5de`fK%#qs6z~ttPY_;bcsofF$ z)T7&?{`%*RJdukU?wSiC*t54^uL;F|jV=cmT_U+)glydQRw8P;a7{!98DvtYG{AdY_{!=TXH zCvYuju}tD_o&`+E3Bgd;YYGOepi7bdW9V{P$Sr2R+7dH-Q3AOtyf_D;?k-R^$l9x% z@wNK;vNqgxAM6kWo0&&LP!;%Zyu+i*HOk^<^`fBuw(nvCa8(WC+VoTiTnM@oLinCV zz)o+gD6ZRVyWD{7XR{;U95&NY@~7<==YX>XSeYOu>|);V?$dv`=N`E4hx5NFWVyM% zki+x2i?gh5PW(36UO>-Y1|Z$yrpd4jgt;MDU}5z2ImZ{lRVG-&uAnKt{}ebN^gq7q z&5GgVNLU@-+XvfYbB9sxCL?5t*CY5_Vz;em<75$x3{x#Fx@p9vTby?mX#*~x1{IKX z(f*=@Aq%3GL^xg}0lZGyr&Ewvd+Fo!HB~=QqaJdt#U223o03n(OaXWFPSp8yZk+?s zy$lP2*guw~xGOz(_4@*TS2dx0n%%iuL5p7N|N*0DmQ6s-b(d3x1G4MEH9L_1`bZZNmhZVu)eUpc4el z0k@TY!396JX)hT;WQh44Betj~7RY*&j-v-92FSGdJ8xJ4I{#y-3ps6_DYtVt(4GC$n=o&} zwakon%0ccoLq#_$7xP)&;vt8cmT7kgxUVMe{FHCuOxYh`;_*igc4`Ajz1x}*$8*Xv zH&{jJjrKWmyIvdm_l!jc|IYOmdL8z3pB(qPIn^GXoPc+(h!=|PwvpwA_;I{$wU6(x zz|GXV1(r}YChtYZBq|oqRjkvHP?0w1!U2ygz5C<~7PSzY~^4 z7MF7FYJl5wq#EQ()u9ya@}GDJaATPD2~y0kM7Ont5~2V*lNXE7V85FV%d4r25a{{M z#gAJ_6F@|Oz?l~;2shwfx&+z(d7T!`pu~tvz~w~VV2`Q?86wFPxV{eq zAsVH6F5B>%?`|sYyl$phR^_lsih{ZmrbWHa>t&c~C9#`eXk14~epdGx(Sl~^@ytRV zy3_VVu5#va`hKqL4!4K zbRm8QtJuLgH-htyuie8rGKQ!r@8zoNw2_+-mOA6!jbQ$>8dT%xwU=E<-R0$ap#i^< z%%59_D9Y zw)wnt5_inm!iX%;3otQ3kU5;zr1zX_sY32a 
zoy-r7B?ahj3j(akO=KJPd=^-Y>-YOg_G00iQ34WCE)$B7x*9yEy;RfIrNP|f#+j9TS1PgX;}FF42KKvL`y1U~@$ zQw;`Mo9}m6La{Feoj6{V6fT3`ES}sd!Pj5AV2HmwO**MXYp;BgZo?Tre$CK|4Pf;O z`Y~H?S6_N1<|2W47mko&7Jo{xQF2jWWok1!?Xf$oJu>Y;-O{q{=SITi@Qh##*6oBq zhf1>&ral`N5OyG)_1H>tvM#<$V5qy^0694_$tr$}31jST0DFwsMCL0A&HMR%CTe@s z`@IYp-$9=+XeTPjT0;mF;4HV#JKnMvxbO)qiRt!wC;kTyEW+13Qb zKj+BkwORxD1uIV!CRIAEs(PC+c@VFuS^T|6ovJ<9P}@&X#0~(;<5)y!`SnrDSAtwU zaABs1sHL8Wyf^S!Dd$y{SmLG;rTuvA(LiAnz&9_))5K_}wm{lgyjPwuBIYZ@pKliC~TWIJ1*s;#SSxAJ{P5Fnd*>*>f7R2bA6XJ3!7 zF$OI?YR3@Su-_M?U@-2o+V1?fcv5H)DJ2u4Ged|f`sTA%q^_&2YnZslY*^lRmuJ1u1w`89jo(=vL}v`hu0ODzTaI}> zdS5StozvG!Qn@8%0oO>Go$io}$`^Ta(sEv3`X$oF>M%hgjC|2bhGz{&wJdp7gV9~+ zM4K;ubbY~V%e*MgoNc7Rvg(!9Awc6yTO*f6Wq9@vDN#MBChlD|@Knse_Tc-R|0gSW z!uaqbswC%3pC0w|+RY@g{PJG(Q({WPt;)Y6N*v<3r}yc}$(B2ZcT&5Pcx@POdEY&V zA(Aiu9?B5r#+<$C|EaDuW(q&FA`XdOkr@`(6SW!jAPUl|y+GKkG{tY3+04ro$ z;jt^Nn^n#6GbU4Jhm~<=>v;}?)<|lm`uWd%?NQ;oU=$;Vl?RmveZQ*;_qBTP;fT}5 zH7DG;l`6wP)5V0=3CL1@Lx|c1zG-GG2AoRs8UEfn+>a<*Ro~Y4n!IiFjrqCxk*rvf z8`x3RA<-H(p#}^zqzE$m8(=#l^|`QI8)7vV{;^$>$i|r?C^Djn1?uXSRN%#pnq#0A zx&R}gPqaQ4qd6s$13Ag`gGCsRd?E>?bhZ^f0!WE&Vqchb+Z@^&4~_ZM<`>pR*LS)( z+qB$#G+SKg*iCUVx!<9KJdlgCf3q4J`o^2&yTj~~M9e|oJCc{J4Li|do;bek>BAd9 zRBn-sb+PeP47c)E!Sb(RWciC2j7OR1PIv(axFJ96b8ou?@KBR6TlE-lESj)DpFG-U9CYiar3d%T9!HfFXIge!Kc=u ze(oYq2@qGrQ$*7`K+O5h1kC!y%EOA5ybv|Z(Q+AO5-`F7F~-i0ApJJ^SAboFPGUye zw8V{$NN0whkdi3cDrm^5j*5V^uzdetGaFuMLV(CC!p>noT-lra2a#Xd8$3>j=Z6WB zs8c(})Sw%EOzSj~Z!|yB+k1K%?8oQ$Ji2aM+59E$9sA=(;0q?!3jW2_TV`d$(oq*l z=L9NwT4Ar}MZ{e9W_(@D&v?4~2W=IDwC>2Ml)%e{_!Q)wi{$6rL|5Ik9!GiHgwoUo z=Qsa~_;=#IcZ;VvYtOkHH?XhlstrzA%JJ33-v+)hu{!6aaY6S>Rk z5S)%(<5}?v6JO~nZ-GQGD8u`Rf3>o(y)w$%AL1;Hk6)ZQdZyZug7W*E_cr{4qNU#2 zuP;T-5}D4rlL8T+)!l5ibAycEUvgKQGQWU$e|hv(GkubX*OYux0Q=YIR@FJ)b2ROj zxbN0CkN$N7$uvHTudPj`d-u&J|nqrapx{-SxM(W*>5qTgQ)EjuihEK86* zB;9@JN~HV2xGU!!i-cC~hLA^&?yp?vf~5C*DX~?Lcc`#)Z3Bzh zQx>l8YLMyx0iF5Q!xr^z^`e-tR{Va`9aQwUT5jv&92Axcj0o}Tq;t$XW@u#UYQi;cx7r?B|5&9XbN4b-OF$Q%W~TkCK5$%7X;$#{%k;Jm*37V)z6Is 
z@!LOjS_dt?S`jfxBx}>+s8AJbr`WnHYd^7selFpYX!pn&Eqp{+EntfCf!u8lyLh$4*Nv%%4hRll_xZULIDBmmYDs;m+9&Ns{{TGnQ zD>2EBog36oG4W5Q&PZzbvCsNsoII_u=DDHv3*Hun0~%(@GRAW6=yZ}yIGub2MNGrr zt%{l{i}2b6gNLP&f{RvN;2oX;`VJi!q68~|) zwjMAT*9myvz3rNtPViBsPK;`gNmW7B7Ze`D$xby?U%b{gl~$)|AQj9Xx$Hnn+U?0czFs*k6VJ|%3qxuF|*5x&8G$GGf2D2$rvrh?-! zTRvm=Yuc3b44F)gcZ<|iXpKzoUiI@?t)L@87VF(#-~=A=#HtLwa{L02%HRyQu~qw0 zEYndiZC7Cd{+U<*4-aNu5UySxA74*^KI9ks%$U{)@C_Nbj%^t9JuGGVW{;~iA@E_W z0Hw#wz7Dl3KQM%7*tuP&{o0_cW8K|;=IYa8^urwd_S1y5XL`Ji%IY*H9%w^Gh@bJB zkK%+Ix7`+AtlOZQPG$wR5>uU=j@dQhOin${2Kra2sYKjmQ`d|eP)mFI@(LCR*38-y z)>XhgpsBJo`2~d%)S2vJ;f+fVS&x=; z8a!8r#5==B&t(kHcngwrnpyXy@r68Qq*wof;u3aFmyn6O-5w_%j_=RhA`MvLgm5w0PHkUaFPu)o#+ zkp{KV00vy{^?t}eZdnu78eek zw7xF{Wd7!>H6BUt1Zqz9w5{Ec>>%Wcs)_t4w6X^HrYuZB-->}|J0qZ`1eXdaOnckg~<%G4!`i~_JfmH7jMky{s$5f z4))IoT9_r}dwOEHFktKA#c$2v@oaI0ky%LpDrNMnKvY^L$GI6xh-paDIyfnR{TtO#v<+BE%qD#KRi>#~9yRMk6Pql2;bUAfl< zr9jta@n@K)J|%badB0inzhE5?Bb^@miLG&ED+K z%w`@ZVXgp1rIE?+b}ojXJ^q9VQq+Q3r(_Jy6Kg5qH};EpFqX?v7EZD@NfYHQ%MENy zB%0!QfgP>geM2NTyU>;O^QRjB7e14{jz7Mm-twM|r~7sPm%WONUgm%s>2 z@=myqV?Mz%tKKSY##k};cY+m@^;T2v^qmSerY!Pxg_Xza8TZmNrib50Bg$`SFi;ah z-8uOSo{!pVk5Ny4D(!#CaebIb(NmfwH*u3-Ima3z8TfuhVoh0z6Ds0aOYd?HoQ&wy#*y`&^uY{JX#-r* zgv{(a@c4U^pP$<4RUp*$h~G76v$Uc3UQi4BfwSLatz09eN286jRY20LsIMDpq`Th` zQr6qZ1QHx`+l{IkVEg_A_w%N};fStRts*{v=Z&my&Y~Q$(Raw%g$et9rBisMEPFqM zWr5$fB?V{+C^eg^s4-6YGCt3gVg5q=ev55ukK%6eCL?1;ue^51%zK)_A(4@B0!=q%caRA zsuyB*A_2TRD|)VIIoE#3@r0Mq*yN5O?glW&Tq)!tVfr+&d7E44y_Zn4Z{K9Fkow1& zcOblA_;f0YNSIo79<;wuu`_*B_)iS!+LSoq%xohZ`Ru^fhV8pSwPP?P5l32!J`zFYG7a@FvU3qtnU8?enHvT`LP z-%KG3Y8961R5ReuTj3i` zesOf%XMY77K}?L61z)D}X4@weNWjPtSFOHLZ2cFRA~=rDIrJnZ%zk@8+*>1h7hiT= ziDqC}?6X?M3MXb+-L+ZzRWt;^O4t-cm2X+_WD5SYRuF#%Dr+$!tNefU!-lW_pNCMa zXf4q`eMK7t`#HDun^zn`v8K2p>r6zHt3Aa4Eo|lNX_Y!pwnl*`1wm~{`sx7f(Q6kO z2D)AaKhvZ_y2>xd1w}kjbKq8p0+_j5Z3bnjOHF_XRANnyas^rBXLAw?=dZOrnWr`H zy8Gd9kv`Xqy#rbhr%UzzzEY2#mJ^)Z3}Td(D4|n(u(PmA9xBF4mruAp5s)Bk;0be>ioed}m 
zJ~mYDYIdqR_}y*x2}FIyk_MgrePY<25UI@TmA6@Zy|n^!9XeS#SG9f+-hfYlYaqg* zxf0Yw+qPY4U|;m^%Qy0m=U3}rgiO$PMv;yTIjxcxey{R-3UY85T5;IzO)hOs8y|x^ zp?j_J%iyP7&+2Lhr_30?Cbn0&+5K)sKHJU>T}y04R#rM@KO>0rOrU1ew;!i;%buajwQEcsy!44E}^?V8zQr3u_jDLEq!eYc6)OtxIp@?Lf`O;o=s5 zK}i?lVd1_g@^Ge};%meyu>$S58Lqncy; z>K$wN7pMqYngfq-`Jrf0rs&@9HNGNmd{KD9YC{W7jBlm#W)kuge9TA&(Y81-$w%xb z7x2k{9pT=@Ca|p$eqoz{@`o5l{=+dn+YUodCXo&F z&p-0&Qd_5NZT8&7|0Ryt`~`>~B3Us(j@>=2OO%Bxfx&As^&ee}u&?e9d4aTeMca** zGjTtY7qPxlG_WfxX+^zz(3V*UH~w`?NE(rK)Cubd&79sAtaRH{FRo_b%>33K=s=a( zeUF&Wuogdp@_jCckqn|s;%l=vGd6~rSlFW}jk1_baz1AvIyW?2_OD8XH|&rfy}mhdfYIH>FnLt58nd8HwO+8}C+h;o;>~xdDlHbFdx*RbEY|&~MAT z+xXr;Yb+I9x73t;dP+=f^OmfP4!c6u-$#ZNa$bG}Xku%;NSVkA%Sx)0bxHKq3;#io z;Hzz^t{-Ib2lWJ=h^i~5&q(_{)A5PARQ~AS6~G;>T@&%gXz^0Wm0b^`KiI&eD!5=| z@%3x9ZMVnJuK*6?0_41`ZF?J@)F1O{oEa83QtKQg=6JE5|FA6e#XS;0_wqXAy2g7I z16~89&5{UQZOUTC95M)~wYC-Ns2D4w_Ir?<(qtEUiorcyTob%HbwD3)c5s+T=SU-u zAeuS)?3qhf8FCfA37we@IDKfUYWbeiek6<>Z_ALmhr_G&wlAaypT~r)MnMWad-5X9 z-2a0!IrE^xZy{GZ;8w)=1FV>wDySH_oOHpoNclvFZSNvvX4@T>2H^qW!Zz8Vvtog6 zH$G&A~Dh!jn4Euc*du{Xb;~ds?Pq=3s+FM~*SC-5?fYg1fridBk z=u8T(#}=M-Ll^w^7&=usGP+zs0}(@GQh_;zC4hN9ByQ*+B&qR*Pp=LMIJpB}RI=ZO zoBG+v0+S1l)hW|0hoVis(WOU2N(>9(n>c$tBxn9c`1_Suy2||SZ^Ret=OwL49PBA6 zK*w(!p*^bJNXa=UbY3p75M?!h0?4gK2`)76O6UzL;3qwNjJF+V1qkYBQxVWlBTOH` z@Ke{4@Sm8&O^OS4ZQ>e5BOi{MK`}eWJtL=)^g{~5hZQZf6&*Sxwzu@t^J2Fd?}cW6 zL$Ul!DhK|^Id_=6_xm|_5sRn0{#b$*p<=Jqgy()niTh;G4DhTMF1LsQ3Z%nYZ~ zrs6!uzVvbQ_QGNAj`Q$H_9c=R0Vz@SkROHa?X5~&F#k5!UCb?kE_fXsEVTh6=M77; zj-+m}gN4GIhUT7>3XDxnD3;$?gXn}~<0w2HE`4)v30i*G`@3f;WAVfF#xMpI6j&Dg z(;tu!Dzbv?grK=G|Hd;6!09564bSSbDw3AiE%|~{!6GOIUc#NqOE#r;fL%kfzGm5yd_0#^I7ZHv7ZM8=c_@-LSJD|1;n)1V*g2g);zsD z2k;aB|yJEP$|zz78tTnG`H z4};Q2-)2(AFK)n({7f3Z2-kTprGv77>OY;X*f!4AHqI?wWUpu<`7(u7ABS^1Z9lsI z2ZV-D%a@qr?45e~fN;BwX05}d4VCNNYuHl9gO^mJ7uQ$N+QXIAQA8Gb2;87&5nDPadiHKu10D_+ltvj55{@zKas zo4aQ;%E((L{9RU$wn%C^NkA`Mb^3>3eY}znlUsD%5lM7#QW~ZTlfsa`L zHwp9ZXLUa2syPhb`}FtYT#Y!sN+d+WEQ@s&%lGBXUGl@=} 
zP2SPm@+D;FAxuuz&3)>^qO;LU4aJM~;KA%Dckbuz+GwmGAtk25$6DJWoj~rOL@S7e z-EBgsqL(>=8*`YEhA~N&%9B=*3j7q$6sQZmqQb$iT^L{f1CQL;0RVMDiEZ@Kn&#@v zG?b++0SBj*b3&;qtL%v})e6V%thaU$&ubIvv*2yvD1F!r36$wGmL>}>H5}GF&LfR8m+e7ZD&oZ>b|r?uUrtU2 zSJ^W&cq8R>?l>eN?EY#lL=9z!g{ri<(QWGDZ`4o0!OU~3q6@1qK#{G<;@Y_#sNkyW z;W!s6qaK$3S7LF%=TvcYe>DQUt* z$opw7O82M}X?ej1|8n)M+)^>DY4gZG-2!^alw($35puZIp*|hD#~}Tuy2biMm%{OS z5D)eeG*^F-k=P6!4084;JlU4u<~RF#12R6VmU^Clh-wT(CdJ+J3dG{r6Zk7-0e#<9I zkN-N@Zx*o%q^}xEO3i{WQfWF&7SK>;;^+hdLl3JMKZMdE%u>Z%Z``?YoVTIt2P|z} z`24im%xsbRKZ$+8iMhD;iGNsdviNd)KPW-%KBW!L@cc^NXR`BA~E&IBTt?EEPz%sv2P;~mT zV#73)eEWPnw60CnjM#SlE9mEGZX5gHA`}z8GqQN5O+k=g7U6z1K6^7k-iUAQlwM6( zX7?Go-isuDT=6>*>j6y?N7?clh~EY2Kr)5pM?fR&hUcu>Ohx@A8Gm^mB=g+GA@Orp z<~n2&Js4JFSV9OPrRx=nf3UopP*Ni-|lbX)2;oa@o5I)kKBDg%VPC0uq3wAu;0g{I0d+^kU_;bYO$>_ z;*djd`Bbzo45c^UU^x+1;~jp~i*~0cx)~D+PN|cd(W%R#YW*bDvzBDxv!TDJ93u1% z8QAFukQXrELy!z92ql54j_5BgNxuiYent>0z@I;Qx@kC2Fmnka5d-)QDm0%`a7EDU zJj{JYfS+7xFdwdShppKuWv07+2RD;H-Tp$Sc23jYBThZ-u5e6+k4L0hnBf8?zT!jg zM;pGj;aA*L!86U?B*Z-8j)1_N%oD{@djqPR4{sg&kX4%@nVU8au;st^a;n=XdWxee zuc?)$`M5YR_xOSUi;qk$>qLSJJ8wvp1S<7cRcf!OE-C+!=!QO=o)y31H|DWZnyh|d zZzZ=9l68~{Bx9;*#4%H*Em7o{sJ=zHSIK*)Ogk;jz6l@Mxwm=ykJPp6%hF6LDM!}bVH@nnKwtJufe2HiHj7tp<`WJr0 zYVfeS{YDx=uwb?q*}h0%_&Sm}ZV^n?RvfO@VQKbUeD~pq_0_y9iMm|VdSbM;^0yx-5;n!jTORta%cVrS`S(cr`)!tpD77x41< zh7U22RndYC@L2uXCllfZzaHvJ8yoNpacZukUjPeM$mv+?sQjP#)>7adW zkDo$5bCuhkbG;8L3&eTc*e}bhk_W$0w&U97*;gu;NHfAp%SCgHp346&h6AN~%z8RQ&Ld^h_hip>SEB)?U^fn{sMa$PrN{O2x&k0Q_sO+oF#C-hI_slo}6d^i*S2KqKdzT1v%0wzjWE#a?kk zEN7-6DUkAL+x-H*m|R0e?_oD#hC7vkQ>T6O>)=Cm$03%D{#bEY&Fx(Ko#s*~IE1## z+lZ%UKJ)@N`>Rurf*^<2y7-mBMM3G(QGTzy8qRQuG`YPYzVcDaL!xM#l9a6$iA)R4 z$C^A#sP}&Y`xNKhJRtBsc{9U45#`im8Me^w505ZSk5*W3fu8f}uaFB(d3Z%RY%}m2z^?umaanrqO*(0hKf6>>k zpK30;=?%wc0y0y&@x9p%mU*wErc5}!g)R%;Sd%(quN8(&h>L)_ao+H~+wPK@4exw| z$INgqZ7HTUHYR$R;Y#M9ct=VRXyjW z`_l~5^v@i4eL-@F9+{C|n@@WpyrwF~>Bg!qKRzqbBh6pZ++}`uWtsOmdSIS<16mES zCpYw~`I_7k8qED$3w&#n{s<%7I 
zji~v9O`qNd@rH=hLhWkTG19;#G)1odTmRL2Ii&8hK2QlTGNAC*rz@wc@?(Ht4(RXK z8)Fj;o3dM863te;xEvjqul5WiCJ5LJLbnxOj zvFI}qO15MH`UStikl3_%?U^p}=a}hE>7A5l7ZH>8Oq>mVc2#&!^(ClJ(|hK#zzR+V z1=1M$v!87gRNO>7z!OYM{e(!=B}h(EcWK!{UoL%e-m04cW>%?2X)_F`q-SF0`Oi@_WbZ0VPSlr>Pi2 z@H|6Eash*T)~>L0*H#f5_N?yJ{q&9c#?H!|*T%av?zSHA=(y36!z}=wBu=akA*9JZ zs+g(JYVsC+yZbp*#rWw0xGH{|GMVq(|3TyN#IoiaB$4;gVmRz+;}}s5&oO>5BK`JQ zb+#q;$tC6$-62Qw1DnGgjH$#`9(QC26m$uv7^hJ?;tAkr&ZP_u8KE|=| z^d7taLdv*)nmBS94bpwTb`6|A2gWeB+3dBA(SLF}2RLq4P#40XkdPy$vn*o629*#FZCozw{8@lP& z$}jtswVKgogZ|L5%jcR??u2N8rpWzh>0xm)h5U-3IQt zBq?5r8d(|&(=*vx8>klA&`e^F2~-S(6TBQ5vZ-DCoaG=-FTgS<@mfVKCT0QO!=(;k znU!}1=mU;CH%SXr=PPwBf6K}DRs3n1nTqY1 z^LJvFQkUqf5W!iot5kUjm|!ui^wdEVQBQ)BAUcHEczX9aYl!;cQjlI9z5 z7hsG9V0upK>JcgYNrUf`kxen?wyh0)CZUtvcdvs}TA4sxL7~~6FZ%oC3W{3<=LO+o zqa3*sfPtptxBVX`V%6N(LpByfne(QHS7i!ZCnZdo@bvM9k+&q?6Mh_MZSXU6KoVL}x^2bHeJX1)`H1qqR98Y7sl4 z*|8L&o62)1{P&XInbTqG`SC)vg3@Vc7Ii5AbV=xc0K>J>-*sQX}+hw-@Q@tlv%-d`h-On{Cg5!gmV!Hy@M{M zjvMb=*%@8deD5;wYH#d%weioE^Z;V z1=JxJ+2yO!W71Vp&v9zCj3%Pp!;(spqQ04#VSwa}f%0#c! 
z@nP13vJ6UAMlK-+!=$JD95O}PFHhcAb6MMYeP#aUjE3dq$`q6oHFOq-(QO9`gKQ40 z5L*>%)oG_oE9&6^q)1&%uV;cHbk^%iZ~S>w`IRsmtv*$kr^K~6azV&;Yl2QndttG{ zsMDz3hI)=i`5~b)UnqNq>r=8#P0kR6Pd6`Y6Kd_K9e_R0FyFg+RFWDlb(82Hw~lu#vC&TGkdOUNyEBW518^iD~f z_GlR+%;yDvOj4B8+$_PXNy_;kL*@WACrS8<(`An4Ng(9A-s8QvM&Hy+xap~Z>F2Bt zu+Q+@aAmy9Ddo?TS5p-kCFM2we#oWx(?C_;^3{GQOPZX-gN*S)DVpb6HS1t9_eadx zrqb4IXp!b29#zu;1Ke|KEis9tl-$lI(4R)X66>;m3y2%l;rH!pWRT(A`|i*>(xkYCXKzRKHc zeV%|g@3p7}@u{}6y~ltim}G*dYRrME1b@W##oV{vbvx@R(ACGL_REk0$=OkCVzByl zs!P5L(`%uYVc$(|cx$Yl8t9Jr`Zgu80QAz!O6rY#%Cby{)vH9Z465?XUZbvx9VR^` zPZ-}bf!Lpm9`-`-hQ=eaO}nh?Pzo)lRyXLoSdNuP52J>PThU7vdi!x5Ei$dI9{Ufc zr!*WGuV!3oA+I?Un&BTk7uc^es^q-#Km22atw|O3d=p4xaYs?<(k;}I4)6Hnv>OpI zA{$KEHRc?w*CziQH8%BDC-YX}@ML6k`UmGrNopcP!J#)!()B%-(vdtOZYfc2=qq2?j*GG*K~HSv5f@oDV*A&>>v7G`%hQw2Go9(Ip7Dn z5VhmM`}|m|?Juyw%>vb*so0t|j%n0id1|gr27VIH-z`g@tP%1&giU@UT4-eXBC~&O z05I^=8;X0u3S)nK-+rgv{&NuBGX5SN)NXB06&7Jh*ms_1ymuPS4pLA{rJB)~3=VLn?Tc`flPU=bg;j}9$;DM1}%YI=|@^T{3| z*I~HUZuMiWT6^->t@BAGn}LXEViy@s0SKo!m@!ST{KP1(hXNW-7QFpTHqI*QA;TxF zBlOxe*!j>CKD}hjXftJgpkvg@N!55~u3;#)P4${qCMZYdIsb-;n&VCSLE-ziOq16c zIq_m#lP=v2H1alr`;S{}tP-o95aXc^y_7eVf%sr6;vIm7F8K#L z6fIE=-8T9xSBWRF$ii=SOS>;S)gkJZlGeO-(dp24I$0e~Bh5Ji_BxN1^g~r3sb#PS z8;Z#HA50Q!W0^P)EF_hE8=ax#7o+Oaav=5~4{3WZmzcSyyz8+%fRSK*(vBFhndMOtr7vl@XTB}Aa|4Sm+Hx{PXr{N@|AA;ZaWCZvR+B`xgwkz>YFy-o;-80H;}K)-|-+ee&aot_Ui;9DPjnz^6|_Rs_F4{ z{rzfZ*X#6xbY$FOQEEfl3K3Uk3O?iCCMFzp#d!xpG*?mPy3O9(vUdYYQ|73g!P;xN zGFMDYoZ83`JBemXbd$>Xt=W8YKvR+48~?M@={HB4Id)m-l-kFeg`>x@_d zzphuT&R8#R7dkedJ~r9lfkyD(`Q6-5$L@tw?M*cuOptftR|b{dZ$TJ5yB)T&x1KQn z@z6Ssxg^xK^Ye$bJK&8UHtJ~0aZ-S>9aE=fU2Rdiy)~3~yHWqpruDlJ=FRIgy3y@p1WQhyQXsd?h=0 zQ(!qKByX?U`>oksWb;*Mwy|?|$+W!(l1pUT-Hx1TIkCDRdWT3@M3>9vyxr)A1VOO& zx!LpmF`Z_Yi9y11st_}-A{Y9@G(4Y|)7Hjp{SHWuee(5RN@2G+qF8Q(rh)w{Wt*}C zH?IWGIJ5D0PFS@Ld64GsR}{!3B$wx(nGOcAo1=xXfvs;bThJV2?Ku}u{)SCz^BPRDt3$55HSA$}=<5)sPS$^j zD7JScG!~`Drfo0KdY@A^gCL(=t;7n&Q$5y*>Xq5kk7@dH>Jo&rSYEP9KjH%#YPkxl zRfXh~Q`q!fd3i&khF@By2&F}2_@OHeRmP|MiC8w@ol-^WfWTH%AfC)oM 
zwBtMNHm9278lLZR93!rID~3wFi@t%p|A;*10aV{R-Cqn&Fo+P#+zs%ijEW__^V{k^F`iivSP@G1z+LM1?TMO|B7576 z7^P`V%?+Yr#w8q?MP2+Nl7*j}7TNnul+-u{vFRb+J@=)ABRh8ixTBBvQ07dFFXf-qk$o@T>*C#L6A{WA*EwpD8F^dw(T zQ+DI}h&pnss|j@+tH%YD_$sN)HX7qxR?r}&1VgyBzuCYtK9eh)^Y;ieIe%UZ{VY9< zN-#T>xc<=lDE6Pq$^`xnJ&N>}ed|K{>f8M+V>Nrhxj`Es+5(||6|*#xlcNYN^GVyy zA-p|aMQ*Gnkh1z0bP48&B2Cs>a?p5I)^-sP=`OtqkOY=5PVtQjvQV{y$O(bJhub^M zWuIL8ksrLGMxs;=yS$|T)F(y+5Q2|rzQC>=4-XHO)1;-XF^>5c{j%9 zRA5US6Hx3}3ZW_)52NeorK zZ?;Rv#;N=foK|13t*BhXaM(Y>0Il=Z2u;wUIW2NKG6%ZBbHa)bShF5U{e~UMIu_Op z*jlxPeIo;d%@j5tD&>e~Gbr?DHHDs*`AL$WM4@))F~dL=O7 z7^&*Fs8ZwQh9p*wJZ(6S!tEXJ`oSkX>GyD|XZ#g|^gCZxh!{SGcMa&V&Gvs~#6OQI z=dGU#9kNeq$9xmZd!#9dhlnN%Iv4D-{oyw-;RySo;lAq~6bO$z|2O?4hdxpk?%NWi zSYvc-BF^&`jsrJezj;DZ%pJ+tBO9tXG%!1lrov8i>Vn=s!H5^xTKiple8(w2#=W&! zOs%tIZ%CjhJyR3sD7>Cq{^UDS*|WiJ|8o~|sRncJ*Z@%iHT=`Rr=My=!3wBzGm+s5 zY1^C5=qqEq6p-wR>Z`ir_~Y*=b`IAAeLy^N_cYlOY=<`9HQG&PX&+RaUR{UM!|DyU z?4@75bxzeX(2-l?j(=f);1iH=_vlKlcfP{Wduw zbx0|BC_i}Uu%>4u8*N1e^5QoBMs4$au41d5{XEyOzT^fq5$WGBZ{AxCz4?v_=NyfN zPgokFDCc)LWbw2a2FX@ierV!pp%03pKB;12PhNi}drkBu@ge`p3CG8$$Y9OS+UQ^! z6)14ag!ywVkF~yPk4eP?x>kFBgEmzW&K}}Rr@=T)$oEr)jK5}h!Dea=g~e)@NIvWuEOuu4NP!UtAHg7ibb?vaM) zd3>^v^pI1v8A<)>hd6kz_pu@Kn_c|)dx;Li-qS&?m_L#J?KFus3R*n$Q5HITEW`p3 z7BwD+q!N-u;5ml%D@1KJV*h-IQr+!D_ZV&59Cmp1{EI_H&gPH z2x$q6kWM~wKddR?898Bj%1_)68M-qHU%mQLeL3x{=z$he>Hf5C@( z6#|8Bcej>D+V}YO>dDEgkkizGwa4h6tCvK(RE1l(4LF^f{Y}jhMj+TY6FgBg#~OLk zHg-!)*9WhEH(ojM>a3o3M}g#Eo8y5E#q_pQ%>Z1?q!|CkvS4!bmQ_}G9g}wZ?lJ9k z(cMmXFGwh@PfDT)g)gymVA~rj9%V(>uymuk<`Vm$fcD~g|ZpN3#RHXO?e~ZdUN9=}NtjYmcwwTfBLbFY8 zV@_Y!$07UtH)sOz_u_ep*wBxiPO7_Uikf&XOk3K@={?K>)hKjdn_s;U(24kiuQxW$ zR2T0yPjN#=M$jCITm#ESc#ro9!?j_Py432)Qq5dhADUW6^-SeZuLYdtS3}NS`{Ygn zSJ=^_e?Kg(@?OwmbJ8LDaAexc3E;?8JFqyx5QOw{C3S z_kzKXvUDRv=a;CxGJOt-pT#+Lj^jdH_aRoq{?2WB3&=to%Hs^D_wl~HzpenQwSesH zw)n_9I9Yj|itg}>OSFEgC|uWNx17t6-syGx1ck8@g_VH>>Swndx5;lDi%5fqUF_cuPFwPWOYz^ju0~? 
zE&Xyb@8iVjYiu?^AW=UnlpsXWRi85Ui71eu>5;7z+V_2V9KOt=qQ`x`?e zaKHRoiTutdCPC%P(r%mI+bCwwuhG?H#)dva_K2g+HrYHW5;sXWvh0TEnk4-K7ReD> z#NN(F8T8ddi_*9qGgw$_VN6Az)a#NOn}aSWu7w)AAR;%MdAA%C1pFc1KQE9{B7;ujO76$ND^SmpHHGKHi9I&S;fR zMZfcvjY6|)6o(lsyD2w_`Hufa>Yh1{VQDoFmqW^K82Q|1A?w}B5V$${rBJh<8&xRG z;ta6pm6Fi;fj(h}zh4CeOOD@qXt3?jC;L{V@ADHk=snrm65Y#tXqcyY9AY1Rp9;^v zEFzQQghjRjp;HD*mU8`3e;c>8A>IoRRe?NEhkA6LrMyw9W~EHOTu}6He`+GIy`Tup z$JV>2IlLP|wmIQXWH2w+1?FM7ZuOHKSSb2Zs|0O&|1^dA=v|K3NlpZ^Bi{ABT9~a%a6auT%;2?UCL9KM{qy*)51k1-A)mUR97Im3p3CV&&SBMfs zzf4T)fCTRSGtZyjlROFK)EP6PH`JIuL5{S9wFi@&qJ=uobGI;P8Ew0k*9aZUq&&eA zC4#pCCoi35T841;gj4}E30Rq$vQ7(MM)@|1oUgMM4JOlbuBLpZXF5{O{}WBREu&AEhPC=XPlvD~ zT)2z^$G0?wh|fa*VM&XP)3BcpIXV&F`&}laB$u`;nUmY4-|4JW#G4;{W)(-xza5yN zB3ERMJ=E!&WPckPFfhwr)crcVYWTW4t)K5-2P4xz3_(f~# zT&k2A`kATG6qAJKYVf(;3?Q*0+LIo2A>)Sscx>s{P9fPNAAhX92X3{ zDS75uJh@n)daiFX_Vv?=j0+wnoR^%qGkesCZ@&~j@p(xKDb^cMsDzEBF zp97N~WxP2$jKtx0?@;RxY_a$&JonMdi|C;q4ekgHHgy@(FI!yF+=r1j_m@dU-ah*b z_0Adh4-$rqe%u!(n!H>4oEk@?cIdO=3GQ2;+M?zLFqZ7@`7Kl*KPQboGjToPTqysrc~FyT(#=uxm{oFUmJ8n|$wELE*uaGC30 zHG#M`-kjoe12eT34j63n_9E64XlT6`7+Ct-q?LKvbyI#J(Hf#R@vDdGsH2AFE>8Oq zpYHC3+X&$$>aPzCz2N`i2Zc^UF>qco=F-zR=yhELh^137;b97GbdK}j5@M;EQ##`P z_9K`ui4GFRb#G4(X<4=e@VTHcVZgw~z6G@J8%IXP56u(^sezOKvfEk+q;#m|5S|0| z8wjKh&07Pq)p-nHNwPn>+yB)I4zJu6(VpJ3myI@JA6Md5X)d=wb~w9MUTCN@;y((%hKQB#`e3S1{U!XSkb-zAa4T1Za>f@*q^oqf$yhR!_eTkH=r1kgHw} ztb6fXg?W#r-tN(!c=*;V3~K3i#Dy$r?rOFe*x(1Oww*xN$<04aPPOdmmVRrKTJEx( z<3MBoMoq|${JwY>Gu1JIGflrB_e`AbTc_gaJ0lan3fygOJ_T5VM)AH+=%ueliQZzs zfcOzdjcO%z`HR-V-CM?GM^f?n#lsvt+E7>qhLXpOhDzQ{ zh}vx4KPX-IUp)m!_n++5L&%>0@5D-0!o)l<+7Bn&oM}RdIf?)(_VFXv(qGNWFFyUs zMitmp4^7D|_EfLR-l+xAECEp!uJy2L^Anuk+9#p1oN6&5EAHGE8U;nR z9=+dxr${geO7WsyPcee7dfW660y1RHC&aQ0n%zVfO%D^ygx0vR8N)>AUm&WrSOPF@ zPV>w_I<2IxJdBpGuuvm9)@I2eOE140ubK^hdFi2H2ry=3R6H9%) z_i)z?bl?+O(?FxC;uVtr09v!&3JT0%zEwtJIO8j)$&ZN(rYSiiCSRGy$EM=m_8(}B zsf*pHYPi8PNr#$y#kSp~g?vQopIQ5WGCY@^smmGz&a_-TQoZdI>o5GKm6_n@Z%ho} 
zs`wkSI7ueWpNh}KN4UHMIcxjRO;UBZwrI#XrJ1c@exhPzS~s0VKRr8bd_w$IUMK@4 zXwju)bud)D>7~!OrL(-x9f8H1G%rz3f=w!Yl3z(=B0c>@mX6F|d_qO$zZw@GQ0 zceST8O@1HD!UY>m?ZNGgy6^-YeT{p${=&w({Syw%qL+p zij4z$zF>xytchzHd|E({by@z9Ao;D##fOJ+=&$>}>;ltx-^2p8ET5~~-x?{iB&XUjsmWJ--H3l>3<2TQIoFQcFkkV~yyk)s)EVGS@G6mzF*u-Cp*ongmn3 z69uLXe8AdnyH`wYy~W1I3rL7Wp*DTJ?`y1uUK)n|(#wY)Y00F$$*dnb5dVOyk3t=U zClAXiuey%~BglA!UyjCl%>mg*RO{%_ow$eZ|RJkPoLk@?I zordA~IUd_p0}C!?RmxzQM1@`PxVKE2-s5*}0}#?JxA96}+cteK?n(L1nrJw-IsF4* zDC39vtbqMOYpWF|5Ga2-r52L-e3|~MqICr2`07S|0udk8r#<;4v3M3;@D5y~JM#H0?$zUUxN8$+LDzu!nK+00&kmtiB7YeT&Qz|~zDRDu z?*?^LoDHXok6aN86&3_b;hiC=)^Jnz330WScnZ!KS%vP49;(-`{-uo3G3{2IHnHaF zClQubM5C!2`4*d0BaeF7kf*_slA}{m4Q#k!0)9zx^GI-w#QQo&4rwtaGPQ-_n|5Zf z;5ljm05i$8W3-SQ{!VGcLa_K8Y@WzQ@$ErxG1?MXQM^#(u90q{dQk#Fl!h3!r{pM0 zK?!9TCLY6(1Z)!uFU3Hd_($oHpAoj^e2v)hBoP|z$Y zV+{d@2J=|?vEBHum@wVPWn`NV5zAo%$B5n5)CgF7T|5N!_1G5CPCd1fL*~Ey>e(jw z$PBH=-{n$3l=$3oYcRL;1$c-A^Q~2rqz#S{#kseRF^l8dl*La?UTlzUh;HwAcn1ZG z4D^UCw1iH6T+shgM4tNj=8Xo)F1sAYclLCM$an9d(XIDnYf&DgnVe~O6SW9HL_@;( z^;;a9pBDY>-(MiaV*J7*eFap@_v=6D-7t*tP%Rlx^jPJF)Il2R@$ zD-)Wjr^3I#%q-BwJhiMv=CFMexTwfX9uYhxmQ{M4t&568nRV`N_NuJr?fef#){J3| z9Z9b`Yt6a-{sZ;a{&wJqL`@NoQQ>9J%-7F@zvaH;pXAwX((_^{-mNMKqpuu$MFpCN zS5?p2HR>Wbj%F4p*W^cfiM}aPA{<(sbSdy~Z%6ame-n=s2#{bNn5X!!-k-~(%ovYF zBiP_5Bgp4FBTCtZvfoP3V(RaBWYXvK(o%I|aXAZoU&|TN1(GJYXk(I=QUWQ zrT==9ko@4S*$_Pg&5-_rklAS_d-M_z|6qSURSYx)Ro}krFLO|mEhGkF&2f^&JRI(t z3M)<|&jGQ{Xl^Uv`y1GnXPo&99r8uS&;&tVm)Nio!Dr4>Ac~!PsXPAs*qGoDuW?L4 z;%xy;lokTog0tR@Wu?BN0-@3rOajGwl4@J0N}zSs8y*vVtpT z9{BI_EN#6gw+3Cm@=HHcLlqJ=%hT79I~n(pG@l}!D1G4Hb4SS1S|kWGh?Gfd-Vj}` z_AOpe_LKNP=3B<$-NGxCo?S_o?oSLdyg^s~d&SHn13xgWK5TNIxpG9xPp(mVcADIU zGlX5Lpo2>UutI=Zm$gU>C~W_t_B2!Rif&{}J?*Y_#kGe;p(j&!vNq(j3^IF3sNI&+ zpQu);)Lpbe!tvu~Qg6EKL7lmV zM?!T*zcx+i!Yr#Wg&G{}I^xA4GzZV|oP3>gPp5<{L75QWXl7F`a`@|RP8<5=?a}?Y zN4BuOq_1ZOKZ8@sqqp;FqoIX%Pti^oSveg}^5uI+)p0-@sR|33o_!mrgBu7Y)D9n^o}2`j$(4{EW0ugTG#sQ&IPw#*`KO!0Zf zuIw_k>H1YSTy^)*#RnuZ|4b_u_lP+%xwUza#l 
zU@1KBW9y^61B&|ldAx&OK|0&cU%NlTHhd!fa&3a2cwbFvL z0EURbBQCCdE0QaIeY0^3(xAuq(e(>z+W(bpMu0vI=6_aOQTLIk9vja>h~t&YxUPfG zIV##KFNF5 zDa-~$lgC#*@3^nSd)q0kC60ZH?yI?M4j1viiI;-s7yA)chvHQX)DX7#oOM~}&8CCCV?mdPa!u z0tD4h7Erj>DJexI_PRb{iu@_29Q{p))CpBqw|)iNQ$}{CU90W+-W5YndImXh=rIdF z1MolnY}a1p_nuzLq4qyXWR0g`s%L@)@o`6qH)^iqTs-qBpNlw~PQ0eaJVJJn{Z|1` zV=j}2&oTH`rTIifJDBE4R|eks)cemGF3*ybP)WP}$1-x;z34ca#(&MXo#Ux$CQHk* zGHOEUy@#4J55BRP&JK#PuR_x9YRvDvy~B*g$1F(>{<=3h8d{#Ah`Sr5^JGX zc+k}x>UtWxN6;!Px&cQi{=q2MWpD-*+SO)uFog`_4Q8IL-y{2H%G@IZFS^jdjgS<# z{3>W2Yw~^hj0ptG_xJ}N*&ZyN$wx+?uy-8p3ooPpBX@7#r$$hSMAD2^-GrYs*E=gc zh@(TpmBc$x4%v;~Xo93R#5?Rsh;ylO=PfHYXU`BzlrwpB-DWSGx-Bc$OJ5Pe6ZqL3 zR@Qj1I(IOXn_sZCDVw-xQ#dE4Dn_fa`Dd40uTP#3KLE*82tcu$|M4DwUG~IF=EG~C{+hehzj4t~`wg;XIvc%h;?Gs+lo^XF zgFGGRP40kf=)+S#$Z`}oHe^GWTsOT*{YX4c8R*2|n<5z*|FKo-?3@TXq!*dTy)l-H zH0l#VhlxGB0hoU}pUh`=3vwFbj)i>h8P+BbBlUj{rn4*O{@i@ zLZ`?t3};UPTp*R)Mp9tgj%aDn6@wozWZdSZceUC@1GmUCr|k{xXlc`!!PVG!OITZH zU}aZ3A@%aOu#1B)l^=EQ7zBcuW;w6lS;}f(I64gU44@?}Vr^RprjDf=&?3~>_pg`{ z+80bJ2{}cb_SQt3Tj#7-%hp-9W#u=6G-a3h`IyJlT(~8^{aJqGie!6rexZAra^0uA zQPXvE4~Ezr`Ihiz{Ci9?kKOn+@)Yt3#tvAQkTGt1S0>O@g5mke)v;rh#C}3BbE+Ti ze+5F!=?ffl&-H4uZup8;)lNh)F!R#?`=43s_=Aola>7F?pbcwQKkfDlzc2T?M|-*R zl$%b~USlk{_M`ru`L~{2Q@9cFb6~Z8bla_=TEwX)Y48M)Y^8itS7NxXrCWVi3lz&e(Ql+=?7WtoxyMpKw^DLuizBr;Yv5RSB#@%a5Pyp<+M1>jq^r;?^S+_a*y7EmW-1lE5pCpo!7`L}(YdM8& z<5x+5@6a8`fq}YZOH&>6-|^ZAEy1%3+fl!~0Wi0)k?=vtni4?aBy}G`U|6JR0jpa!0o<0^G^4wFYIm z<`_MF@s>gU{R4e`L@V9`>%9zDr`B$ijA>uBtJa^C z;M#Ud^7AgQ;Y!ekSNBhKRl;w&lCQR(N5bu0lid9HOa)J%GuUv>QQq>7&Q$13EB2+A zTEC-xLYS~?1RQ_lB{1xKzPn&+aj+XjBdTQZlhJ1-43|T z|B2E6f1)AgqX923N&3`s3LUNG5WU{fx{+PmpTN-Vc&&yfA#Rj6H|5mrv9n<=n7wFv z?N|dFkPQsPvVf{CO@T@maSa$nH04?JWHa-f$^Whc?|DsiowdgwDv;pVAMGNA;rpHE zz*`ClwyltT$i2LJIYZCz8lj3as)*KLA%Y#{hdsaw$?OhB#nHJYxRGR1lx^gu2VzOK z9+d!LEy2Vonj{n&%($(j5&-CL609J5J8K{4Q{}pD>HI zP7XYvE*!wdtDr>yv-wziL_v+o0{K(b`~Xnf9o6g zfhPtBI)k3Avt|tYhN&l)1gfxwcv#7oRiSKJ+~-u*vmq@D%0Onq!XV+)JLqKP&a3-y 
z?2BBvmy&tU^q7#$XmVTa7H2=}10gO@5K>wpHbg->Yj`r~GIS0lX5Q(C3+(}XzYHv3 zLG_E^?NPlP(f4^lFEl-6z4*n@L#ini9SL(_G~x}R&Ea6S2BXHaNdnsM&{Ra=3oG5dNv_onmXHg3W zKRUyIj?J}){Z8xiowSSm-x~=gu>Oz0xOlo)2cIngC{ZG-?Z5TjJ#%x*K`l0xx!CF(53$FEL+N%r7D-q{ARK($?w_X`PhgU)*nW zvUx-UT{tGR+^PFIGnl`HC}b#=I3DYBa5`b`?-j(qCBsw(W}@zJG4r-OWI%s1yAh?4 z;x7YFlb4PEEL>#Hbj}XWp&9IRsFfwrrR?_kni6>=OYopsEh$KUctLLly?YJ@tnxY! zMp>Sn4zW;`M>o}xT4Jh!Wo&jL4CopdS^tXPntg1;=3p?Dd%3JDxEDSGv?sdMbwbZBiMxaAu`w93y zHlDtUFJH0F7kIlWdW5J`5I*oq1|kX!_xx0RT_oo01P^Z4%2|=<>iDx6;;Ad&7erTi z-|OfWs=XZP0y`}LFV*r5Zs-0kz@nn?Z+kD*I zF?+g^pX@aWkEk%jeI^%#Rr!C&Wi&G(Dz zgkp+b4%1qz7_TN<5}qKa^PqfyGt4(1CxM#uED#XMTqlqK=Ib3*!IAdG^v}USFDn|J z7c*zeJ)V~6AGy#vbRpdmmZUm=98f{e_2RikwIX9m^BnBG{(nq%S+j2ovHB?kr3;SOqPw&Vz zJTb9&_EF?UD*jI}+nl_fq*X&3KmMoz8EFWvAWZGjP{h@6Eme0l}b#w;&39NrlZxv%+OkDg4j@* z@MR>7Da0nI)%I9Mhg10vu~_I=V$dULI&pzq33CI`;5(CKsdGu+p6^CRSACJVPmA!f4pPK+ z#$j>NOl_j~CDC;*zz%Qr_y0G}_sj$95 zpVmyJIKzYpu4JjOA?Z7F4oW1!_(MrZ38h+<;~>@Y$<{^JBQ)O-_!TuUA=-!V#5y>F zn*f7sFT7? zL?$#w=~!~|5!|1sZUsG9%<{_#zg~Z1EyLT%!QAkVFS@EUb>Bl=sWw=8b?IjoN2PqP z7Y)X1=#+m7sn%Tn(1Ou1)V67mK{k%^R&PrfFXo*El>%O?K?3q4!Hkjdf2mJGMlY#W z4CVEN8t>KlC7;Be_YMIs_#fM!f3H~%9089D`(j`~rT0veKG<#N99W>S(f)z=R50Yp zHvvwHUtDd+ptMseE>qwQkaifq;#jAv_-a42t1iH*g!_Tr;F%kivPEn1m0m)}e9*X6&;w+xr5d z9J3J;oxO7JF==?F057^6;Kz8pd;C90Vs;){po|>yy9K>&Df&E^^jQ61O^A0)>(AzB z&>R`Qiyl%4=~jQg%5~aQ?v}a>_Y2|mTs(E;qcTwPJgdSxgMaydX!AN}OdNox1?|mE zglkR}oe^}T5=bf3XXG~6Mj(0ca2^Mh`=DQ~GR@d>{4Ph^oZ~v98IiTfZ8>_+ZRSix zWaGoCB-^iDknXqr#g(4QXG4gPo19P?HT^5VM_zwMElVRajEw)V?-i|{x7b$yKSK$^ z{~AXTRHvwl>bN;4G)ykRLI00ZN6F?RtQFVcqBJNc~2kw`M~i)}8SL3=?3~|fNbl&$3V&rGiAL0EytjBW)_wJCmt%MYZfNn*GcgnnQID`F zV-#g6kc)fQx2cD9QUTvOA5|?j?_KKH2Vk3s$?X)Hqj&QLo}|G!3wMA!x(IvM0%Oj5 zRNCtP$D*P(tE_mhx}OyZs|p`*#Zq_dx-5H0DP@yGN>~_^-&%j@evsn`y1Mmr;>R@L z39=zdi0)ola_m(&QqpB4w#w^Q#08Men@$CoZ`{CeW4h^@eZZ$XB`DQUxn9>Zlaa_N zpzNY$6`D-;QgXM_75LBDOp7gMrjSQW1%nk#XL~6TmyAqL0(FJoq3|ls@=+JDY!n%L zWsom%YCOTPR>iv&kbf$6MhQO_cevv#`>t&EP6F~d>m_nCXqxwH-8ZHKpL~1m){>0_ 
z0?%mvVH*lx2YDnWbBt*=QmJ`-xRUm2CxtG{{@H7ilnGej<{lOtNMRkp;Yp~lcS%!i zVL@?7TDnh<0yc@8Wa3j`zdWK@G*|_kY7#h1>U0t0@(__+h&!y3(J## z_4YmW5G=4euvnucX-goPI0y#)^|lYPvZ_)4VK6>}nx0|YdC|xB#(H?GfY!Ip5~MKO z)S~Tel(wW9FoQ}C{iJ^*&97_p z%>`K}Jp4LV`H06j*e_Dk%>60mj(gstXtMe=uc)MKS>E^xF4|A#26o#Mh zJC3Q?(hxgLW#7RM?=5P8wKs0%ZT)|(*{Nq*nBHH4V~@z7ml_SCb(O&km%p-2uJ8-w z2sH{{3r>u*X==`2Rh!gVV~BR6%ma1>B+`^C+%r#mi$Il!6hq=^h3i}`3CHW$aUQoz zcTRVapRfsY;y<$Al`9YIHidw-bfVMM`gTYwNRzeUG23>LEhypsLSNR|y8-jG)dZ3a zJ8KMgx2~+#+9e)|1YCFxvrqrXmSYy4kyKly6n0SHVuiYy-L`Z`_h9N0?Y~8k_rAP-BsFBhrY-S+FOC zihY90m4m?17vlH$u*S_#31F^if(^37OM^kAa}*`Er<9g*vxfm@)W_{#eZr8ttek4s zWF<4S-3ITaCT7ZCiWuz=)YW9A?zJnq!zOl{VCCjzEwAgKx##{xyltZq#GXI4qPv&zlMKPYb6htgB760 zV}Sc(bS5ggUdsWs(-x`uZKAX0rwFSr<}$okdDk+H^zvfBq&9c|HdcIraUWPFXM(%I z=0gB}MxON<3WGvOKX4gPvloGa{VHh1(X@rYi5Sdeb&oA0Q@_(ASMF)AP;7MZ7}GR>40g~;%_mg%!gBEu;=GpS|&0t4`n5JobWvu(+f;Hg* z5tBT2;a^&S_5)pGxzZ7J_}hv~%(aU@KXH($ixEe?j4du_hAG)s+NmVK$d~)I?#N8b zoKcYJRddu3DGn#sQZL*6eod{zBW(+x0>NQ$tuiZCV9;!60rZ;fIu!O9AMRnZ9*l*s zUr)#@uWKKr-j)lbf^XinkL|2nEVt~5vypqoL-kxT-5FM#Qr1H^u7D}&f+{n+)Sc!l zm|Xe9Pj`tI&sBez**&!M@SUGPACv1K$ef-2TUOc+b{jL+sVU?`*^hOE7bh>(HzkJi z-EXG7b($&ne-&n$4nXRhLOq9xG_(hAzZ*a)L_WN%G2}hG_uMFuJtKOpZfJW|Xz|z9 zCzKY9xA52AjTzVbEYvku-=LruG08r_|2$(_gvA=;dKAxQSaXWQm=@lici-<>U1M*m zpo#I}#CDpyxDj7A=p<({VJvQXop(Av17iZ^RT#!GQNAOL=0$-`ELCoo{dB{vg!ArU za%{|(i+`-mPIp*r_~*a$fXSH+rS@I|#9t|G>9Nx2|8&K)Usfe}rgE?QdC}v$*PInC zNB!lu{(swk5gt3mQ&=$J+&SYqFJRJUS2#$`_rFap;viq{*8X>+UjBF{1w5$xgFZBU zF}7CpO&d&Z6)#+(XiFAieNH_RbQ9VKlyeB}sZWw~f4`=dOcyn3bd#>?+MUmm+p*+h zq`_v#TX3c|MqK76s6hQc{l2!Ky91&8ClXsr)+sGs7a=)lBX?_{YS8v)WnsI8FM&Xe zSNu?o>_s{4R2gY9#HzT(DM4s9$o2p0zD?Car8E}~7RrUTe85LZfq6HHaj8@t4bl@F z?(GvlM+7l4E3do|7qnL@w_Y`IISyEb`APuxAJ~_7uYa`tr^wSH*6)5cfB&o}z1nO# z>>T1teyc2y*Sm|#*qZsJ-xfFv)Vr9*?Gr}6XZRIkny0{KznN87>W*OuSlY5-wG;?g zDW1b?2#EBQ5_2N+L(5AdQ27a zrPxqbjtMJ(*dfB=U*eg#&BRdSJtGkEy0J0{iqG^Q0@MJv9#9FNq@@X8**?%;#&i7J 
zRab^LCeFKkr9YOwd-FLUj3xw$Fus=OH;qg69@`@x!u00I9g9C3({swh9E8cv@nMF4rTK<4c`kXI!Q^Awy~M!`+lvD7 zaq*tu^%KiS7eOiTk1z-A#W4tJXnp0B%Nf9+ue+OtY`(scS>JHF{dxpNVTSh^mbzpn zdw;$NvJ1KeoRce}BQ;VIlPo_dG%#zd0_KULXMarsUHl_ld+@);WMBP}$J#f61^n{< zAG`Iu#COHk(74&)+0&xH1!ZqJ`*{600@|vue#Wy|Gp9*BDVOY|A-Rzr5rr-7X3XR_ zTD8ZLMO)L1W7YRm;JS#@m!77WA#g*8*z)hI&->RXoUV6KnO4=VS?=>Vb;mnbSzTv=S z33HN2{Y9(hNOL^9WOz>verT(FPnY4-lIxP}Z`ZrPT6 z+vIs$fPC?NC78?EJ8zU2*z~D&KYvH=Ne;vqRKcsZuB=ndwONED0?uf9|KXh`k0 zSId`GXL>0eb1e#z4Pi^wE7H{nb2;QR$@CHF-Y&p~d?^y|%%nqzy}8*W3%QAyOZ-@? zGmyJJNzJY%?Xu9`Nim~xAB~Tv#|T}ffXKc){g6iGAo67MXW5E_buX*}MdD~=(=k6N z(16k8o3+0k`tX45mW#TRmlW&|8BWKntL6ye=hIe+fxnYf2w?9H>9&*x9xYH=5U ze9?XCj@Hr#=f|Ve&fFXB#r>|$Yv50&{)JM27?C8O`Uy^aNZOT|fq0jWaBksD$kup6 zK(feKXzg4FybWRf1Jzq_L>_)%RBIWl9^qTR`cP1I1ROX~1)al>wW2x;Pr{C4;gfuYEgZuWDX{=gb$||&he`(B9l&ogJ zN2E9LTgl3OD%Kwq{^8k;Y+`O?_7)5dlh+eu^Fwi??u8rw!=yP5m@uQjvQe3#N>kQGOqN+!slHs5*`h7X;slnvs5-frlj{Gp)ef*b*oRm|RhRnhS zRaAfSvTPk#WKNrY?j{Gv+9F$;(QnW-jcOT-%ZK}Qz<(Gh{(NCQUqjxdDymBg;5^tr zHDSEub@MO2Ml*QG;}pLvc97k&wFjE$-4w5+zPB%fOl8JVay+8k`7> ziHmnr5UME3-wUMLUtO>uq^#`M0z})mXwz&e2>R?JH074-lvHi`-M{JuFB?F4?kdng z%@Pn>TCYatxS~-0M5f0VwtJWwu=Qc6Eq{n@>ByWOIBQlR)s7TGDJyj%G-}7y7sa?x9$-x< z2A4HJ4!PUqj>M=yAR1`w2wuCD#OMY=yhqQsa2~3|Lx+Dxw%KKy{|G4HrZ8rpeF z!2i`dBK{5kP~Dw68dc&2Ht1`3AIyZ)a^y)yHSVMNf3|eQP;G*bUCq>t^wVFl z$*#kGvob{4fVk2OISD z878XG6i=GapZkO~(Q*E&q6e;P6=d*SCjGeknUSWBF`w|N9$!K={bH&)(mRFpSLBM~Hjs|XdSs8S}9e)T8QRCmK(qnM(>NVG0fY+jV z#w%Jg_p~f);xlCdaTJ3pi6YVROOquKQvXO-dNLcZFhvC`_ByKBQAP`lGif6fTGDxM z`JOqB;rUGM4l+z+frb;VkY)N*IplqD!Lbwd@T@KbBqZd&xRR^P4Od52?BuoL9Ym!A z-;)!cI42m`)I)&!b?ZpMobRQ#I1n`)8kSvND5$y;yoa9{-;u5OCvtaq=o`&PA`m?0JPXz?0sm2t~aWcVw3VS zQB?eOPP+51;ci_Y1_h5!_+N17Qug=$ZdYaAJxqD0j$OC1&e(Wm@rua}>p7nqfA(v; zSBmI3Io+p^@JnQS_yvibacx((@X!t!q;t57RfmvsjPH*&F2&%P?!2Z@zmt7g%tQZm zi2>@RR}f*|4HG<_g24!A&YC~5yleT0xS0p`S<}<^!!o>m!{kwz=A+e?9gxLNLUZGT zR#0rWC+dyI(mNvz;UuNoDi&d?-$D%YYl=pFkpXCKAfkYhvQr;_lkr!na5{48F*}6G 
z;a;#HD*s0e3J{ss5%UXo=x3)PVMTO|C(tyTj8q9h621x&kBq_rePFefg6#H?e1_qY zl!6WlW>Y%cjBozF5`adO&!l2peWHOc}Grltfo?S<@E@*wy8>y@n^$ zbOfu+@^HVZ4246StI;4x?9MxQRESYGehJ4{m^eS+R%D~WL>0(UaO#n+^Kfs#p}*Rb z<&6h_(dDu09cZ+W-7g2Y`-@d>RK`+mUL0>@5*pt@XT2w)ps^ubOjdrN{I!kQUF49X z8x|ioes}psMm)7+#us99LT9U?c$;q{Z z%5Vn|AsXz4z|z^26CUmjqVhp84~Z>efo71zfn69~wuPs9`zf@CxKk0h?Ew{UH%7W> zhjGdfi>*^`DYf_V-4Q@9-;|i#LMQ2#pQK|w^|8C+zxGbz@k|BLpCz~^IParbdjhf- z_T;LxG`66S;dp^r`(r$aN%wa{;^#l@!;?x{-9ae60lJ+?Dh(X;<&ms1Eld z5$`tJn-qvklVBXfNc0!VofK;4s)P_4Ud3)w7}D-!JaqpsrN#Ijf5|NLDtn>;oWvID+kufV z_9~aDtO~zpctHr8?)JMEPF1fArk?opdLwxXzx2!@x)t-`cxub#e`T=CY^s+@LvKkY zw}c3b7Temw8d{)YgS6H@8G^)Roo$$?Ir4qKQadI?BKCnJy*-C^qt~o;5d$0>3W1I^ z<5t9Pb|amw-=Qa8mjIuYKPA@o3Lmxw6Yvp?*lZxs~nMYC0Bl zYfnb#RDx^JTbcBzhYhJ;-${{DJz^{YP-)+g)R4GHJCgH*g0H;(rdymO+MzL|$)$Kn zI{sej`J53-8Q(-tal@C$0^V^(-JH2L@q4rU>SNvYzz`p5N)H9(kaJzb!4J(C9MF^i z$$y-b&%o_3P)W|c1M=Sv;82^tOYCLY#fIaN@76Ow{rnAxveF6CLETsw9ULl{sWUl0 zeh@K0;m(8unj+wP4CrZ3HK$KHKM>|u{W<~Ez3XOReQZS2Hq`!eQIw+Jalk=&dE^$^Q zcyHzFc+Pi(!j1O!$4oHIkk1{EG8{&tMY|>Hm@|86NVeipyp5YGvfN?4f8Qz5Zh$kM zqk%Bb^}zH~+R@(SUcAV~NM4eOVl9SFh{b3r;ilTzaw$($6oOEFI8x~w>uAtwh>yfw zi^>f^AcXEN>@tENRJ~9Kq&-+1Ts9^F?ni)GY4w_jY@fsWd%V39+|a7DEkgP@sNulo z_8R;o00<1oy>GU^Don_v#v~O%eaq2IZs?XMNkrpw2%b$EbZcv&qv#V>6%ha9M;8^d z68PVM2f}{i3HLM3`Ut*|^Spk9Q4uZseZ`k}&XpHR-iBl$6S!h-+!;xgcmIuS{Rq&M z;uxelLT`CUA%jvrtIk88xBdr_HXd$|`OkRgS6NXa64qJ|f#p&w=QC3uQ+{@2`jI*7 z&z!e)!(1HV&sF`{G)%UGdS{eLWLT&K`*&X8+9?^ulDw+4PkIUyl{#sAYW+7CF~!i% z7@2DF(;RAzJAY9t{=|dolH#vbu1RG#oIt)G#DLR%g<756t^G^%yWoeC$Xmj{#Y2FH zvCoY7_*VM5cGsu_kII)-Xs?g_kJc5Cimn9MR-+Q70xY+w>6B16+T!ybod~LPxqn8Q zZFKN_kN{FY*AB{(+(i$L2eS8g&xP5jq+VoWC3-PdnJ5OI{9dhZ?G9mpYsaN;z4MxZ zu!v!eV#ND~Y;W@3-NvH=1$^Ke)ZFXS5HsLea6F}USged>ca$ULwDiI}e-S3C2IND* zCBLcCvY$hfo3M@<%-Ve%wGYz*E5=VY@txP(85Lt1!C}Upu9~9zzmJZe&wanyok3F+ zs4mr@e#~E)cp!JmU(5<;?sGIa56`#$a&Kgsa~WCga)yXK%Qd1J)heZ1$afL%ydp_| z*w{T2tRn*YOm}&_M$s(u zaX0ifCXfRJ5#0Vp zr#}C($s~bLXm_blcFf<`;Vx6l;mUkv;%2m$cW(kY{0R8s$gM^Wj{SGJrvF3=5Q&wS 
zt~t4If}SkanK|aDAGY!51S=rYrSxt&liItB^Qo~pIBMauYNwtU*xIQclL7Pfc~A3I z+ti{t6ei}aDK#xsS#-(iwnx+DerJ4hVBx6DCAypL^z;{BQV(IjHF4y=S~m-#^0G&- z;6U)9Hml+o8|e-aqCLkz$>We5TJu&7H%8YFB507>+vM+xXtrIS=P=i_?D`wFKaXK9 z4^ij;F_KIOxcO^HEVL%{+{|OrN6GvIc$aYl&K9^1E+hG#!A%TgkmenUXa?*g88j(~ z)7ykgj8L<`vk0`Y(a_Wc^=1Dppk_rWj6o%{`f{Iq>OKd!^jwV&>4#asp7kDq;>Tm0 z>hv$)@A$az5{ED=wOD)XFwk4FU+Rg9??)I&&DDB=CH?;h#sP_DiDqovbTSZRe%yo# z!KS(-2h7o0mw=UUA1hdqVZ3_xQ=1Do_RfUG&%uAwE-DAMwj74eO#>Uh>xUt+tC-#P ze(e+YL!rr2;YCsIt0RS-o3|mcby{*;I=Ao_f6Tq9iCMpPoxbaIwLyRXb9(L8gu5}% z(K7d6Z}~o6nE_@cvdf9vQ3CAg&yaezJV)CpiVD)to_7S#!4H)R@oBhaR8SswDO6Ca z4uT2@iKBg%SP<4T=n$B7cfKv4QpYoB+VJGP6OzpDft$F8ZUd;XGb4R{I`s7GqkVt% z613heRWTa2UlNjG=Q@X%Q!2`jhn9LjqnDe1Nx?5L1Qr949pW~VxU8yeXa+?Ja}`9! zzykiJXO zl=ukW>%;v-sc*T{z@y7+o(2|`y(QaUaUXClFkdb`!?NW(uWSKLQ`dwWY# zVPZN7M(a=hVzW6x=a??1<^P^;RKT%s>_*k~)Z9=!p2!o{DRrfn-v%ES>W@B?@&zc3 ztKVC(DU-_D@fqgg?Ia*-@^k6J9W`M?z>=d6YiPQ5%a}|#R+PKP4wkRMMk@?S`MZ!J z#Oo(=fAXxx;GOf83#je6QUGvCKAA7cH{HR$8R3&60@Jg#F~>EZCXD5rm`FQ!*tiqz z-Tbl(E)&|Rv9W0gEDnBY9llp=x!WJsarSRbckfXf0P31+9)>pcpuQ&8jbt}-F?YG`?ov_`F}5+NS? zN!%F5o8PfJhX$z74RZzJ&}^y^1Kmx~e>~tqe1qyUW_nK6thIe}dXSTa=JGB%7I29k z3rG6G%bV0wY-ssnwvFEC_3|o4KP*5F>sj*YyUS)NE4byW*)CpOl!6^@SW-3MqsBt5 z(EV|rPb_zOx)AR7FMqs?r8CQa=eFV8J)^=6d%}7kBu-LvP~wh&TjF+6LIdS6pllVI zuKkmhPW*NBDEP)^l5ir?YD+kuWic3bAJ8+JAHQIzF!Hf z3$EM8XHh|`5J^Umesnp-&)aSmNtOf4j}12ezMX`TsB1)69|QzNG4bQPQg6I6QHU%x zTxx$6C;?i_EfsP=DcZ78D@>WT;&huvlO1W-?9#8u8|Gn$rwpHMNXrJFgA$5c>SPty zE9*a_#=y#ob-P(Wt`VQ_Q9IYl`fE0$W?aZQ!+LMrM za~wj3hXzwseDh-5iHKLoaZA_B3OHLv7DMk0`3>&B{tpJemI>ZIu07^{iTjG1tF>=H-DC+(RY%BI_l2h*i5-3L#rTz=2~}6*f#|+~RJ#GzUXBO? 
zx#Ho4`>Jb~i1Vdm#i>W~Y?T@HxmVmhowhnsn z&SvLzoP6NvPF6iR1Xd^i4);D_F@+K6J99b&NM%9CKaH$QA(2z)2dE}=-qQl5N>36o zB~q@>FB#RH4T+T!-yA#cpC%Id!|u(dJ<_~IAcP6t4lFm)KJTcXb^tcrTdx+U;vn7p z6dLr4h8?iNJ1Tet&%Xf>+@C%(2Sr^Sgn2%%cpW3AbsF^Im>rcw#+>RO>8l}V%GaND zL`upS|MY8qe@Xwu>j(rPYH{Zfogh79tk!7?gF?< z1fwUr|6=1~QmVg&NPqMZ8sB8Hz=@Cladfij5vcz-y@C!wV?w~I#1)ty#wR#f^%&KW zV?S@@odd-+L)zHx+9hbC3gzfTv(RX5oUC_Ei%CQGfRL}Ioh@@n8hG~ z%Dk#_a7OSI|g!;QW?PYgT#{Rj7RpqI$UcSesGcc*l+#`q`;DsSX z%qAgT4XyHI%r{jh&YI8KW_fPqQ`?r2ZZ=Mn?`GUfXjU`(n&}}p#cdkd)_e6jWLohi z_T(?sZK>4Ca1a((CZH}g%Wr-GPtW@-C8riMN8TLDm7 zt^svb|B21HyQnoefpX9>DWAlLxTYPN%dS&HI9ctrPI^dHqP^har~t|lej$UUeumJ? zuP)FrooeBTW~sAuaK&p+i;A)zd%+KD_q2FvJ$pDPE5i*ind56PIO4x4YmHDD|Adw+ z&k|K!a)?pYxl{5UsT0ge{ci7h7_gVWnx4w76vm-${bDX~Xd)YoyPbN_`wL~l-8 z%Wr6uD5(-YI)dnb3Ll^dlChgWWd6J#sCV>$1tIr%s-5q7OwaJjpXahx{{5k2XX#9RH`PI!QZy@ zug=xSLwRQ_)pb_7?DC6MsT#rP8h4P50l`Il_x5qwA`0J_s6T71IDQW!XCKXW49o7H ziv3Fv-ndF2LxW6evQ)AE{TU(lMsUxNcvQIP)zwjp=#U+ECoX`>!se_`XzzC-L0v;~ z#X3te*|k(#%Bp8K+v=9|NjU@H7)7HFJC@GFL|(BiwHj5G@I`}eN?B<`p{CuKHMX%X z>!8zY*>39g4op5w1Iv+m9xX3 zrr(?j_$!3HCe|CPF5~Qjr`bb1{lunMa&KVjWs$3UPgF-`Lr2H%bg}!lKJ8}Q-}3gG zjRy5Of2kT}1W5ckX10u4O9!YBQTq zFDVacLk0pr^V4onf`hGeS7h|-KVd^+23$|Y<-jjD$kU$hNJZzi237;=%kh5< z*&$TmriitUBzyeNw@B^Jg_ql7k^kTwq9TT_mEL>03cX5A5EgIXP(yyUjHb5o9ffKu zsP*H;V2`~_rHXwMb!DISOyy#%=!b&SdvJMVpPrmT&H3)f%rpvBN01)yJou=1y{|!^ z(cCu09?O0qSm&4Q9g0D=BG?T>54D?%zMoEn1%X4dAGarLr=@@(ny~K9y0dDZl}YHT z)N!9~M|#x!--4C^niuJG&jSN=-P8r5c3tD}IJr+*uiVNlyin?d(_Oy#eADDp6q?b! z$75V*+AAS}&`@|o>d@?66)F=f+dEkg2Uax24<<8yYu^|{q3P#GA_)IJv8iVo4l}JeP~!%NY+H z6h%Q_n#rSLqXA`F{+@xns+yWx^hwsG(J3;Oiui;Qci+4lHF{tsW9#%=uk2PCcsC)* z(<=#3tF+l-qGk#=Hq709DG}Qo-TN(MEgpsfF-Ql+5riiI-t@G3SKg1v&dVdQjgaoN z?Oq^(ieJT}ookET69&4F=L_<#mdK|HmZ#L&zhI&kaY}uq_s881^+CQtN1tGSIgU70 zcgWKS{+<_q2+6WeQp8ACuZYZiYY-Q?EE+^mv4K3H+%NAu*^ul@u;zOS9n@;r@e2?2 zviQ*4619aFen3EQe>zlP(;S%Vk+rV$T&<8IuKV5czcb8pu?N%ljtD(2;O!<#J}OJm z*`Decvvo3{Lu6sog%qCZ5v$>^u>2N5*yq{%YRoS#tmy8sU(*sbP+>{{RzxU1+7KK? 
zYJAudZ1RqpJv;(?NMcl2@f!RDV;RK@048F}&TD!Ea4!D{Pq7r{(IJXL3t*@(yG50L zzkjb)j_&=4DCno%U|K4S(>D^2Of8%o|%~FDe_+yW}m`2iNe&e|*X;r>Hoh z611_PQb3={)_DT+S<$fXDfM_UhRNBWW3t*CowQ`6?}YaQf(1b!>K%J- zLrDrBfsqm;^lxyeC@AYBP@0MkSgpUeb{sMIJ#;CdJkU=n_t$3%o8R$7nF#4_Vh03m zzUc}6760kUwkywI*yD!HRpfm27X^Z$HL4+kUf-Xud4MyoJJFLDKn$Zytuk zQ>YZ`E5DVI%06t~LAOT}A&oc$m`wcrWwKGdgAg~r}Q--s>M}P0W9~Asn zcLEY?+@G1#Hxtpy3n6jX@RZ!7zh@hBad1RF-$#0PsZ+1aI->)E>fN!B zJqP z5I}C9nflh1i0mt<{1dj%X+Y5NcB(*~O^-rL5)EQzpd9^6N>|iC)GGG_MYxN1sFW~* z==0yG-rC`~&BaCra;tWUts%}z(1;Ihaq@L*kIXQc`OmN8hor4jwTk6{qkD{qFEYcj4GQO+=@fXcqoY7$+u|UxEf|Y~!U-TDw2b zDZC98V#fLA)cB}bU((Gy1Z1vxh`Rs#1)?-tPbt)IH{;@Xc^tRNCrX%`c7hCFwTZAM z(p*UUm5V-&8Lnx*M2lhQ?&m6lq!02gqXY1jlZ0U?bvV-b$~w2$>jEAO{`;WjtLTy) z$S$$^z>+5^%&!=WP$GGf=VxBnRzvbQZucY?hSkifVB7nJ8uhvZ1kODbvty<}#b}k& zRnVVM0?Fj9_IE56;D0X|&T1TRy^b|8DXn{b8c!bVTd~Doe@iylpIwThrF8nfYxx%O zOu++oSVpP+78|xEEi29S3ttJzBI(s&B8!t~l2yp~z;BzPn#5<1g9T4Mg-PjdpDI=i&J5btqSMWjTs52iUc+ykQrpv&j)DkZV|Q9i6j8yti$%svg7BKQ~c zMDNOqOE8AZEcz z94Xpcr3al$K8eu_kBnf#;Id#B%taG7%2o0)XJ+;(x2#F+iQzp8SEyK50Likt4u zfqu%Mhx6|)C);y5c`aENE_4Jc1z_L3*^s4>IgNaA^NdEZh}F3~GX0urF1v&!wN8h6d(zA2tJPmMP|2szgsHC;;$k65q0o`wkzgbub}X+MP#Vv zr^g<9yL+j@_1U@K_icLw&AtH>`=Tu{JDiU8p!5(zNnHz zIv|K_-<1HZC}k{S>y{a{*BKPrx4)==kIB4t*h%2Q#^U}Uvkdi@o}YG5rr3I7EAmMW zn~1FP6Jzv1H!(@<+HP=i7u7y8#uNyD?jXu;GCwD|;Q;xK>>ZEhExQ+;-3rWZ!%jdz!0bhZDyTS9EF9K`2Pkt$F_Zhsf5 zS@FAz#>+q_Y4ih4QHus0B5B^{^btRem8zH#D*yFerw7^?2FwVG@4HRD ze&0zJ6TNT<_Qd)0pq_z)zOpChNW3%pGYa@4)T^v#+&ZE;cWBxW07QR!HQ{+cTixKK zaAD}+dp9sS4gPGvwyfwYw{=l8r%Ic#1DxHz#U3J!syBrNmX6&u?AvSv%uM2Hjc|eb zl}8?yVhS4k`hpNyEkfaCt=|j|z!0d{8*75)s+A`-eqTy>!}e{gb7j=)*%A41@kX4r zDUz4Wt#T0;Lh%iMA@1iJK?I#H3zmgMnL@miTD`*sF%Z!X2(muvDR-9-Ka0*VE(!pT zx0B(8w9~D@AXpZG_3ixBjtj1Z=wjfyr3(}~TU1-a`dHH8>uSf$JHjSF55*!pV%=T7 zlPrvo*nR+?bvnRj+&s6%FS1!I5034(?>VpGpOhe>jjIAtn~g8k-Mv5;l~6&^MhO|* zLXT|R6p`9&KG?=($HNkFKMV~a7W1d6E87dfsO;OH@ZGv*_O8$L>r0r4t-AF%&b;pO zN}H^=k65<5DWK8!Aw=Ffe?z`Z$Jz_e=>`$e|H?TJtlO|s3ztrN3lu1sHs5= 
zkg>yDObAf&^j?!A$Fva&EJvTH=+wtOS*H)&Pr_JDeX_uAseHu%Yp)4$Z=hvCpdHw*s4s*$3Ffc=O9R%C6~jCqbyi<670($)Tnu_gL_+t>|K zK2SoD6^m_TJR3?`*QP6apiySvC$@~eHpn{7H)k^s&!OoJY4CmuD2;F#n}U;=dTb4Vg_>=KqGIv_C;yfvbjK+@xR9;G|s=*Ih@SYh5>QVR6hUuhXY7_ zf2Rc71+@XCHs>y$h3ly})qgBQkZZm3s}NZHj?An(2xBh48*y9*A%$)32`4hSA&@Ra z^aVI^ROS>U6SUX5gVoNR%YXqjGUl`vDRQ-j?w9Um43ZJ*(cVz|4#eL3HwRG-AOU zC^je2H!y~dC10?^v*&Ki`;s))uAJcp51S?zYFBc8yrGYW6=l)7q|~5O`Obl{Ds@4R zHJJ*n2JNb@?%DTK8;^Z{0|i3MzR+8J_`n0%~pL@Uxlm-1UZm z*+=cY#P$n1?IlG3<+JN7vHauZV{HBP7uBg4e;k7TT~1JyQhV%gP0CWybbJ_%^-cV(d#{CP8>&R0THQYZ#3)C_r~Qx>BsB_Iae=j z+ArSP1~v!bLOVg~%A|6jabLZwV`usdE)m(?X=CfF-`r~rGnJrHW+<}JV)f@{2HlaSYhLW4~%o^W($mhbbBRj)l%19`i89OgG)B1_ru2C z*(~PB<#yyX-E-gZ4z>^n%ofQ+_j9Pu`uCXYpXwQ3@7}U}g*9E!gQ0kO=ljI03c7Rm zP0je2(OA=Jqg{SG%p?xati{ediKq~oTsLos-R)Wc#6Ko0-iTV4a+W{x7sO_BHj&N= zfq<(cDvY|%r{~qTwHvNF+-UtnYFbDB8ao1roFbxdb~&KOZpaWpm-GMAN9=&65iPu( zs1sVU7qoBxDz*7=85^hj>(d7@_%mZ>2ZdQ|!3%+;tPj#{!B9_c+*l6Y`cbD}LzFA4 zNkpKx8Q9DiWF`9k&^Nu0r|kY67`0Sp5NvD9;{~jAu#H?QJI=O%c^9w)?f}uT&;F!(?!1H@RuwB0E@nTmC2T+c9CTJ(I|_ki zWJq|fU+_=foyRBUZnuP)eI1n&%ccfCW;pT?l*c zlS)Ab7(QY~=zb1uPs6A$yY(1Pe~*-VhqTDIK9}KnZt>yKC|s0(_*!>jOWOHhf$)fn z$nimQMU)ovqjIZ^AD+tuE*;r$Wxb3*UTUkqpJo4MXjbUVH|b~^tL2bvp@ zjcb+coUVd{8sT}dI|JsU?xJrRSf?TF>@l^q{YCLCEb0m2HEm7@S_39p?$f2t1(#nE zjrt3Idb4!C>KL{+|O5A63Q-^ z-YQ1nL|Xw?yxZ`F7ppRBon}g`F2ZP#G1uq4McaW?``c~vgHVYtjx{a&yf;6_&)$Xk zRVOg4^H~UmyeFTLJ|aL{-&I5HYgwvxrk86=uYddiD=O`c@AB_2jgDE08>@$+WX~Wu zseD5s3!&7*v`}8@h%EYc)UyzVCCJ%t@-i2S-fliNzQ%`P)gCpd64L9p3fbIn}l|l@+cc!;oGyL( zoCP$j(w%or0AxdHg%i4eKu9R8t;(%fsMaIa8`(O$b_}d_(Wlj2u=WpQ{>;jdG z-F|@;l!4*bQ5^XF=*;D3*)LDAqMsQqjTY}#n5mT7S_R%pYp>tq~HWa7{p!c52>;>ylX^=Vba;?~=|3^`1v^YGIJ~gj zHg|+?!GY#)=3Q4h|7aXAt_xeZ{!lSnx8O~je=#^}gt~QXBK{Q(;}{+|*N_rdeMN@S zP(JNGKxx{y(uaAt7TFILyQ;V!cNF8 z9qK9f$J9W3RzmVyQ)OPz8&cXzjU!y_P=pJt`fRbO-X<`Zsxc37NI$W3^vBmUgke#P z>`DFGev>;w?erp)HD3a)!L*3YlWDq6$R5-@qo^`!{t}vp)C}GGAU>3d$Bz+Y(HNZUs?SpL;d)uUBpByTw}is3RB$#@0VK z1>*s&uKcSR_mVAH_d7SPt4|<3u2)0$2@1pt0kk#%hr3~be$zqywh?l9RRi#FZ0YJ2 
zmlEBtOy)kwE9W_MyUIt|$?-EZP&bo3yEl#_#Xpw98!m}$Uuh-?%IKgxkZ~%Tzwd5h z1uJ+=(AH+*SE@SP;6UN6!p+dY#D#IhDNK?)+p=#7F%q1{YUGkZ{T|?UfU&HagWIS` ze`L0_VusS3{l)TyUnwed^M=YEtA8no0-mknBl*3B<7dx zL85D8X$lE0P^3ZSybMNyh;A;QGl5netiJ5@btU|*B_Q*mGxfN7InJW~n>!GPMe>ni>`0jK8Lsn{6LY__v9rKTE)kKV} zG+(H67l-`#?IHiA1VHnm+Jp4v3Dt6VMkwx^TA*=KZr+Z)T)69`II(%E;%ci}o3ao* zpKTXf%1|IiPuUODvM>1VeH8qW(hKMIQ*^TxeP)ciAJTnvRdizof2$y1ga9DC#|D-n zuO>K;5^4zq^Conj>lEB>K z-{chi>#;>#)b?GsD0(iu*fLa#+P^!BPIkh2hnW@qcFde@kYGR7qixqBl7LHEwlqy6 zN^n11Ew&*D!E?juB2;RuBlw8f_(@%T#iMst8V#bdnnD6;BW1OxI=01!VV(DFP<6k- z`=+5UVq{bLK;&a(07<@x;p$-q0WJu@dyQ>or`pwZz|-4U19Y288BEQwGgp-PZ_a~) zW)u?XEkNubLURuaMgx<^J6GZ5n{d&e7P7#kYn&(+Fv@vq?&3c@xPyfyNqNbToS5$A zpR4nxX!*0z!~rf*@e4|tGumzL$l@W0A&7FI=^4X&^uB zbj5vx8V>T3hPmS-@ktsp3glESoMV*)-)iIk#=Q|xBXeZ;)}4reLjfZ*WPcfTEY+rW zRI#7eK`2*9%8o?Cfm|l)8xMTp13Z>D)w@4r32`EY`lNUv{pOlTTuzAwZSQ9D9A@jV z9nJorm(MX&Gyq&JTK9Q@k-TOb00S_GN3&o{)%*~K$c~qG3->B$BMhEynWN4frFYl< zTivuN(ic2fQ(Fv+@4L)7)i>|3dM%#d^2%qtNYn?YLV~BwpXlB7U??#j7d@ky6sMoS zSbwrcW2#mjJr|52>u|E6mL;wiYO3O`&gjKq0m^ijwu<`h>q*L~0FnB0DW`MJzpx zDe340ekEg5B9Y^sVzZ_JIpj4yMoHLTOAG)$w!ARQYEOeyW@bsyY{(v726-Tiz$0^F%s|yAHbS4(i zA^3tI(Tqpt(4Joy&=Mk04v+FtA&Ly-FK{qyRO*i`(9T(s|L9JWK*6cMLG5NV9w^iW zp)FSVN3;vjA0HzQi_+V)MB_K!jgUMo|DY!jfMbtV{;vrKjaCIY}r*(CwSmxZOTHrg@={0H23eb}v}8TcDU+)oMwq ze}lHMX~-yyA%_~{3YlPkZjVb)H(M`o^Eol6vUoEbSr(^yj~WIbMGu4WIXC9^LxPFf zP`?QaVV7RZ5yeKfvh!}gW6m^C>2cfc``5KDA%kmEQ6$1I=yU>{Zb*KeGBALT;x_n} z#>zZ*V`tx8Xm7Z>@DURo=jQbB?nW34xc*2y+a4GV%EtMIA1(4qY!=tEfE8Oo*mQfi z(I6i7j{MzE(PB#dGXFD&Oj!QKfM~58oWBzpQH^|t{e6^D<1}Q|!u@klP=Tarr&yX$ z`dvv^6E}0+U*7_c`i;BEiS>6B{FkNd?=%IPHYDv8LJ|m~Jo0}Z0HF=4apyBBMAp2y zCseFI>3OvQL!m}}XqN#5l(M*n>U_g9THRycXv|frMY!<3mP96{pQZHDIc?kBxSupsRXqA_^&kOnV+&>ZQ zf~g&L;!YJ?v=}@soHcxD_$Sld&V{nJ1sbZg$WV&yJhd1nB(8EPi|t=1f|aoOX~>_}NqNa>f~<8p4fcus!T6#dSdq5qX?? 
zjJeBixV)_26*>mHN#S29jta$6cL|bADyJIshF9J9^ck(}1(sh2^rdc_>s$+1?6J8r z9*l;zKc&X~Il=gRWHCh|@gz^3rRyVda}ui4ZE6r3bZCB2WE-w?!j1LQY5kVYKJo8a znxkl19{+&5F%FiR7fm#&lst`CO}Bt-nL?9ZIL!%koRoi5^ii?T9xIk?HY_$YuB}JU zPwjSdqJm*;dA6xEgO>ZbBahpR9L*7Q`n%5eX1Ip zJ%%8iH8Nwz2K=F$B07DOJt#*yD$d|xebt#1YL7d|!b7W}nwn4c4Yfh%^pU!+{If)- z-aOKK1Cf!2kU$7#hox${p{;B@#YIW4F%(wmQ4Yv15;&sr*SNqA*I;l5Z-zPxiCi>o zkx=+ji3f4^&9AD#XnRv~=v&YE=Br=mp<5rI(x-RBO*UoM)Y$ia+Txs;;AX^vmb$*k z)nDQht&R?&4oD&t-lfdFx83|W6~5=DU@TTa*jEcpx=;CWXbtmoINDIWunO@1 zHm2K^vM>9P*KCF?X78$@6v~-WTvpWEW5*zym0T`yK}Ivk7LyeYjs}&&&Jf%y@2S!G z;v|SmI-2aKAY?WX)N1|9#ZN)k>Jq$8mQnAlf=gHu2nlr2geh{&vdX;C0kJz3iUTqU zdkOr)ZNInXVWvS8d=O0phPmYQYZW}YXXNeoI+3>yenC;#^!TJZtxh>@OaU&-d0az? zmimS#pW#FDu0W5}@r_15?YE|4x12uFVNIY*RHSJkhJ0|Y;*>sGXV;>9d6W3~b(T&A zws*rgE9H%YVQzu3q-ByaMC{y}<~(*Bs7giDqS7>%l=#Cim`N*TS52gLGs*zM%;SvBhh zuHRe#ucgJNlS)4uzzAqRubRBi7+G-=sYw4*OyIztjb@k|IrMXx5F!)#WX}!~bfXHe zjSJgwG_3WGi2d6O|KKw$ITa0SBN|w-O_z6&&NRurKbb z&9UwGkcW}@v=}7~UzUt|c_eLH^U%~1eD(@p8uQNIk7dmqt{{7cz_{=nnNR(eGsOzlF^!bm8}|@{k=uvTTjX7kjmZ}bdq4&yXxQM`ta*=LIQQ;g{l@# zj}vJ$uE=eXUpRVaOb#a3`t8iizy|%-j>S^oSVEI&;b&8`w^K=-ZaCAX`LfsES!>o{ zdLL|>{JC7ma)Fd4sEr4l-8l}j0WuKQuGO<5r>w7-Y_R#<7)gnE|EoH$C8qFjiT+ve zQmhaJqJ0)rN=mF^8$6jrh3Il5sqFaw;^{1-+G?Y1-Qr%{gS(aB?h?E>#UW7K-HUs$ z;_hz69fC`7D^Ms_in~iszH{$5e>0LFN%p(k9DvG@vFay$*bp13x2b3=I@5ow9eqW2p0}45Dbl`nN2-tqh!u~U1~W< zT&iP_^!n7`CjMnwMSd@13B~ADWM026Ml#R47goCUh#*QAFVp&O>I^w`*13KVTp=(o zIN|d(`ElEtJ!boeFnLfTsiy@eR~JZe2zZpOuBxLd{qs4Fq62MPvk6FjUEr**xZ+6F zbEQ3%-wQ{L#|31*1Xy^yq+R{x#^*WtFJ7`F%6z5zn4>5Is6Uj#SiEU0ia+9Z@an>c zw{+5|ysy<&S1-eXcj#s8O(4(G(!zs3?0a_DbL^!FV;RU2ay;aO3s%2f{p5u=+`OBb zf&E50Y*0tb*MU0_{J!D4=b%I+DvHSqmxlGeWA+UeqQO0QbE?#tr}$J$)NGqF0ad}h zJG~h$WV)H%_k>SFgN#$aJWfD39kSlk;2N?FxTneTfo4aA3SIPp4SXX465Z$gUqZ`MHNsP$BJWZeZlsX+L-yIM9h&rz7 zvFQrF=@cbCP_w#y>Ys4HrzuiH2kUnRB7gRHZggG)0UsUC1kl75@?^7Ae;)4L&*ahd0BN`)|5^ z@?VB^;sCgaep!0hKKodBxF)klWKR&9p7k-uK;+%f3X$30obRxD?SNtjVXHcQyJJKy z#y1C`2@c8j8I)B4q<&^ow4(3gDHB6w24l8Dp{Z`8gNq|%gpkl}t~+(Q1-h~qfJ{Iy 
zq@-w+b87Ol@^6FZAA_skVx_dI0_rw?5H*VNQ3gWpI2$j+Q+J6URMz8`;=R;qCaTNS z)*b<_*;|7fpOv=H@Y!-+-n+GeJ=jt49b=a2U4Y{#7oG%og7AaajTIL)*mPt0u?&67 zY#CTqM2D*aZjql5*nWm%W?OrR?z21OtL*$NpMLFy*7QitRW(zg+fd6E`1yP& zZN#oP<@Y%$$CKIpc4LN-Vu(oP#-En_{v%8UXR#| zgFS~AV)Q>85!s>j&SDD?e9yrv*gD)SMT*-GuJ3A%1W;8uHh1(`Q|3-V6@5jAUVE2L zFn+ykO%}I;km7hA02MdlgzBfii&t%c-j=8Jo zlhNiy|A%M9=v@$!53KvRjhY*1kiI^)?n!z*5+B@C>~5H{zVQoW(>bSpV9jMZJ|v+x z$6@Qv&@t&OW4{_&a%y6`wp1^Zo%HRQsR;Io6!f3pv{YZsb~}lis?KbV+6cch&d^c| z!b1otcDhz!7eQQ=c(Hq=g;TV{9WzJj!1AtMm)LuV`$oP8=SC!fna_gm*n*O#^t#cW zlW0h(Qt`Gq!;(1q8eqVl4oWBzDzN6;9lN#PLOB|8=(fG{n+3w$?3o!IwNy{raf+@! z)4X4jm?NRx#TodZB_M>iY|z)J*Rq9f z@5%O@ghDKpbO#0ph`xNY;-iy_AHK<%4vai9GPmWWfhz$EusOh+o{At25{bv}!LIJ+ z(AL3!aEQWv(yFh<%GvGs(^|$S-JI65-f>ZWMB}84X=bN;lMyW`0WgS0b^`cU(NFIo zDcqjbr|E(Azk_2%yFJbFat@56y971Hx)`vnp@-gx_sl59mxb8pQO(vxRcC;yOfe&6 z?`(aD5AEeSKfrc{vYYEieIBQSG0QRcR4)O z;A%ChV(q{!bUncU4CTYWG81Iy_(MlJ`u! zb%}3CzTM&J^@>oOW#tBM6jwFW=v%#1NZzbB&O-u=wy-9Lp>Pg;DwY>sdTs9SH-1?o zHhi2q4c%@tC0lOYoMR5%qAot+oyq``QRG?dej+IW4;Nvl2FS5>A9gdwPpX1{B_qhl zIYL+;Mu3bCF?lMk&U~kKW7*? 
zw=8%rm7Y?dwFZWfsJoff+qAKtikeBuhYj~7Sm5rA<-4t04mmLOS6o(O3pF7glYNXH zCpx`-$$jPwkVB&xt&@CM(@k3P?Qvsk+s9aDSLw%|7)2CLWO{C$Xf!WlB7;fKlJu@i znEP{az0pVwHvejPv&;pL!4y^a?Z=qo#gA4N{N~OX&16SZ_q-X&p6@=Ja~bC`EM24P zaRu-hf^Mw#;SMU*sLIZST$fJBkM8;UqRc2l&DftrjzR-W9Xp7*t3T3KPkLSPlqmzlyXLo;yRI=s#qeUK#4&NV@&T00TbdnyyBEQ%z&6v- z6%e~Wsq{;P?FsqE3`hWZ=9M2TW^+R>gf?ox*4yL~3oS_&FOYHEy&=lVsrJQO7)=s3w876`f^|>aR%NhNF3R7Zytp>C9V78yIDDW%`T>7b1($<2QtXj1rgcsTvjav z?K$}nY04=F=hC10A0Qt$+?wB0;WiU^-h^)zwBHiz{4hZKhzrvAXXXfhj_5>N82u8= zGP8y^J`QLt_K?9d*PeXBv}lF7$m-tDXN0}jrsPW?EdjS@+9OIr*rN&nIO7pl8PON3 zgG|2|CY^pfI7^!g!sU_=mDHEMf)HjlAgg!cMS5|ly;5^v6PRI!KpGw{oOp9eNyR#R0!$GaH)GBix+f2zdip zNR zg`WRJmg^=X>T8LnD#(35N@~3@&O^u@ns-kX*z(#i&MaaVcII(X^^5}*%1w&iw%ztD zb)hRCG0K)VIo(hqwqR9`l05U-?CQAInO9)!)YS^7kSoo}>8qIyO!7Hi>KVG=eBSFy z?2|Vse6!gt&j8XGuF*rO$sx4**uU<=SGCoA^Uey4QXIV=0gUd;ef zqu+C6zWIkC&R_ct zEQ*j9PI}Qc#-Hp^^a@75i+g~4&c*^Sz1*TC#0)oBs?nwnsbtv+X4lj999qVaw2*G6 zXaE^JC}Y|b`8cDixQ{mniZWSogv1Ec=jxAFK7VJ=67uS#`seBLBQ0CaFC2e_s4CJo)LgIH( zI?mP;^4xSR;8p9oIAIp+iT?J&6(jk{w(C>e^zy14l&JFhm-VLel6idM#s{4|JWl^% zT>Wj3GQS zs@Sr!zUAIH{AX%&6XQ=(cYD&u`DxR)2POgQ7l*V~JRrw;RiKYuS8z>Q`mH_H^LPTx znYWUp)j$yXD;Fk<-NQ!RE#Zx%++ATu765wz1z~12;_r`6zS(rp^m1EpHv^BkwzP>u zfr=2{Cz?M5n~Y@E*Qhl{#GwP~b5VxP=;NJy{QKKwzliz20B9H(%w^`~vpS^W<9|Nj zq)qF`H(V+Z_uN?$2+IMo7G|9M)r>H-o~;v3%QC8K0}O-eS}!rfb-vCsS31Ikj%=?7 z;7{A* z&(%Pc+Vd5u@l=GQ!ZR2hJ2epJ@zkOn7V{b=%Mcp*@JGsf8+iqXSx=7|K$f(1K}|TF z{c#AxU9N4R!F`IkWL$UOi_WT~lQ>F<*vHtqK-*-09q4_*M4edleMj{E^U!jN^i|m^ z1+VW9^B;8a;%wxhK%PHQ%}c<_?Ca{^`v8V%krsC}$&0E4_p1ih3dc#<9zG|w)WN+6 z1odTRBZGWq;#7atiTYOIv5H@jpYg0ssrJ(selKNcDn5E3`=}r{xE^xZc#*7$0nJ6r z_KoKRc9|2-3Jc8Ilre6?G{(2XYx(pI|Yus7U6h6LM^>x%7Jrvt#s^%dh4DO;ZnJhOzXxVc{_U+y)Dfe?Q zC?6ZQ&3>c+xUqq9EBt+0-rx8%_L{wqxS7OHE$7ssNUl^T69aQ%Bn7~^g_u%V%Ov|faWnq?br4l>tc1Lfmj zKz+Z-VQ&7#)SH5kAWz-FZk#!*=}%5y|kpEC)}Rk7GvVYrjS&G>v&oT-0o zHNOSjI!=*`6ZY|W;jK^W3x`9;tq5O`ODxXh}8Ki`#5A5 zS6iH~DG}wZf-)9VFs= 
z1{|~zklNyCWOXd*-`GM#Fw|wgmpL;-895>A;I}?r44BPa(CVg?(NuCIF;W6_fYvV4RpmA_v31(S;4(xoSM*A|80z< ze{$E#%^F`8$TCFi;p~4HhT#pbIS}Ur$kTdzqm4}04wZ~6T%hM7>7Jvk3c(cLJRz*X z54}om${J|l`Xx5!jlUo#Fc1bwFa5P_DBsWSOhdysWD0%3m5DbF|MaiX88@IfJp(p& zQXbriZxcFI3h_j%`G*}_rZQ6hqvXkMpBPR)-{t}{0q2(A`vJE-N%U+7otyP&(3*T3 zZM}h5dFJrMt74PBOoy7LUjY}^%8mg<|78-j%I`qQNQ?w3EdX8j?7jUyMAm}8HZ)C5 z)o$~-kS#BMGN`$>zpsYnrX2W_j34=mp!S@fn$RKf>mcti<>3Z}XxJr)4VGtqILk*LUT zTFtdI1)+U8$XSePfHh=Llltd>UT+`C#L|)f?z86KA@w-Uut2wDu%w<1M^NGc5@5)m z1D&r8!e=|eqxAKH$51tYZ;MNvpvdh4aNz{if{>zBY<(18pIM3bPisi7A5hdw<*9Dp z(??}>kroDmw@U!$de4K)v;?vGmiu!|8**=VyD*iEb@r!@v&A*UPm4ZHJ1G3+$CHa} zpNjC)qxI8i134Sp>|xvC-^^BBpqM?reXWQLNwYfugE1d=wp-B5cFUo{79xfL5d+5G zBz3p_OJBH}NE0pv%Xy*9KSakzelQ1T`(f7r;7#|*A+sLHiAz--o4%Nnf3QO1rQg+% z!<+8%65_ei1shI}D1j6}5xJiByP`O>$)I++8F5m6inQq9qam1FPKC*3PPZ3q8(o(2 zXqa3^o7Z;uVKBDtOF(9hj=iS{bt78E93R>UnFc0_@jTul^3_0s>my_T9emc8c6P?>_lgHphTO~HCyXGA1 ze}n!AWsu-`Y}a9=M7(OGf$dv6iu?1EKTiV({?^fc1&_y*ttO}v9fsZ85xSiu`H4wP zPRQojO}y7jsyTm08_R$Qvxm zV`))i804p}e9ThxhgIL&8puxSWv^zZ%A-beVaV5(wKNR?Rc^1*S?>eASsu(YeL zhp(Tx64ToCdeDU&R}ovla+af5e#)_@68lNh$yAEt<89pT6v4aX(9} z`o({(KF38+n;dK(ULxl6;ph&P>mQG1&{43qMDKE`)idzWzA2qJI6)j&$hwTDtf3{v zsRl(Nq`3PV6SFB;$nLzqw_vmH2=@o{>8mbK*JOzU67)cKz2FCysp=O|G>cbwR6Mur zd-J*O@CGiIWKU)nv`EbdwgmQl>XdGy_k zx>h&N0JZx`B05rZaVDsuetDBLxV0LjG82q`szEIi66}1D{Vw5>jMF<12P&TsTD=8Y zUAR2rq42_x?AKUR@B45rR}}p8p&`*zyu~@V@)|$vBZFCR;8sf4zFo=w!lkK2n{WT3 zJwnKVSg>AWb;J#(!rv5=c9PIR+F>{reRBNf2gD)11=yDMT7xmWzVLnYFVOL5dkr-` zA%%dp<1tHO-Zm6GB1Ti|2YhdDC1JQEZ0p2vd*vzNF!#sn``GL;d!fi-Zlu8x$OJh% z14`(>8!2TyEf#<1FRnh6UBVTx+=y67>%ugi302PyfikHA*f#4oCEU3SS!3Eu)tX3A zF8W0N@Pj%n%pC*&g>UmBvk}#ftXF+i()Lr(w>%vt&&1S`=*tXEXJu~Y>bM~Bu{Cef zCV9`bTS6@AhU}Y~29tf4`1hyZlXJQNw9*P;)bVCx96ZHuQ+_;#kyDZX!iZ($rX%w| zwxkAU8Bv{@5u~v|&PxFkGQOwzg%wvMapXWb6m=Btd|8xSow!SMzlT59=$v1ckU=)x zb>}93S!;EZl9Ex88Ex1cc7bTpJ8Fs{Nv#HB_9DTLntvDJgY<-+2C1z%zTnTP#va6!Hng?ZJO|JUsU_~(RmJ61E1o28EGw(|uEVqX*iG4)IG=cuTh 
z4@CK9&dy0M$$fVj35_m(!Y|1hFx{LpknA%V?mcffC65_=-B5}G=t6VUgES%ltLGW9(ucc7+8+I8c zS7CFAtV^+ituHm zKKx%8T%3x4DL!Qu2ajrWbmd^ir5!d7FZ`yNy7Mg=qF=h{Had4(^QcpfE?EChIjH;q zw$Kc>xl8ejB)z{ikt0xEGNhokba6$o`_G z9sZP1VBwp^snM6$%BI286)N3}St`W8`;537ORGInP94zAQP~ee~ zyh%wGg)pa2#w6LKH>bC+85t)c&;bHF*sV}UtvbtS^4VwxU+am`!YzC9ykUIJ-NIKM zFAvw^Ary~&$k1Dz#ZtuzY`D&jZmc};+n+2G3|RYHYzO>w|DT%HCV)S*Lu3NLe6be> z|L2P>CH~ty2LIUO7xI4sT55U3j?~)Kz=SigC5aStdTTf4&=VKoz%!r(jmN84hFR7i za9GH9mcKu!XNT}G#`m-G{2hs#_(w4g%-=NUU>WV^RF3?B)tp}z>S-5<-4|z9 zs+MystF_f5BH+SximB*2=A5eU33Dos3$+h;|F!h57Ll3Oi6qd6HmGfCxr?c0Ni>p+ z5b?I2rs);uZc8XSNicd`x-O0)5t!>kwA(CB9#nVO%BSNaph z>r53vMkcV9P5YSez8g!@QF|xMX?+QA(aO@NwAKEzl$gn``1__+n<0!;VCm4wuCp~U zzV5C-jl@!0^U24n5@O{f63z=OMR9K!dDjR{iCdvs6+j*wiC!aE3xPqt>_@1r?@~b4 zy)2lW?h8d%eTSsWwagI-*&mSfJCxR=(Spt2)tuTKeBjT>(1U)lU}sP@LU2;cdgm6h z8^*4u>m&PZaN4Sinn6nguL*lk8QxpXhGr__h!fNBMaEBPY>EORP1b0Crrc;*9v;uw zP1qZY;VYMW5c^q3x;cU8E5oQo&SsDM)-KDut?{6qIPQrGPUhBA@-06>JBwS2*<+=(=0@BN%ezpX5@%FZ&`g zAM1buFi{A+c#|f2noyuB3!xgU5d@BO*0h_|Z4{&8vW(LNnw>g1c zVZUa{r&!xJWGQX|u!(Y}IF9Nx&*2GDN>VnqiJy8oIrF$<>G%m)4`{`^QM1AO%A2a9 z62v~iAc#~+p$z?v*hL?lB%(5R9uz||$O~IUq<*3Q(CVBseqQ-C%tpHlXt$&HvbLyQ z6Nb2m^ zBY+fh&wzYIweZP1-ez!OXCW(4JXvC-~se9>=T!;G$ET)QzQM5++=$r*3 z>A4REmgeBg{$O30Jdm0nJ53&~V(v8?O6WVnN60@6LbIMkvN((0q>d75W#4?Gj1-lp zB50i-xudF!-}D02Fc<%Q6OT+9l!MBkT#J}-Oni!JD6aY2r7moGikN*t$2L_=(0O%b zjo}zmjGoROAKpb0U7Pb;tPW0zwNGM`;kEr+TjLzF(5ojO;a=|IVM+;ChIS`GXaFsp zFsL3Q`66zDid{c5UIEpC_()KOoXFI1KX`041!OmO>g5!(D^g0(d4xKl5cdf+pAD=!nx+cEO>e=6eAL%Qjbk8# z*Y>Vt1gpv)z|9z^-WZx@u9f3*LKS^1m6MtMnc>C8d@j0y`|p?@KL)XWB&PO(aDLLr z-*_kR+Ym|7Syp2EuQh3*E3MV&_YDkl2!K(uLz+A$b2Hn79K&-muwstO$xD*!Gd}6DdfwkR-W6igBtYd>MB4>x%!PqsqpjqY7n~s5@MIx*>n~Y3XndWPw zyOdv&MFUVm1mq7%_8{CNB6rJdGWe@sd6FWhREH{EeRB}ByoLhT6J=qGPx&fEUszZd zINdJ|Kn`cc>`EzZE$3zfQzXRJXKx&$wg_j`l(qOgXSA@?bvlF} zKJRVT#pERi1KCWvl~)~_B?c$Fs7L7YWb*2q+O*f`IfUx4C5>SQE{10xB)K{tvP!!C zWbtAQpsI4=ZVB(MOoI2+xNwZ}_eiasEh7CoDxXu>V3gV~is3$z&RBO1S1!c!RC!|a 
zKeCY&n*$kce1$A8F~N4iRz`AvCW6BIrR)e)>qMClnUds(u-;KZ`X{yoRkS9mqr5Z8vFIB40m1#p;Wd#dsO{t67&$xa}4R z*7609a2hmTBj-|>o$}#=x#`V}R0qv!j6~YNSkOgFjF&Y3f?cA%@a~amuu~XuMQ6{J zn?-V;VE|)5r-{?cesipm>kSe$Q+fDBS+9?tdH_~tm5-vsYoIb2+YiaP`>W4a)LaCY z?$*6uhnAE$qyC87kLTkAn~l?v>L}$GxU8VsN2!FTTWnM8=ne65d7K^_rI_cq(&0!3?R{@bj_R^-1a_We+3 z9=sgw(wp}&S>_clPveHthG53@Xm*>Jj?l?F%B6QXUD6W_Q;*24q1+7uCO^_8?|Z^? zvhEv9kG!|}kD39ZW=Fh4Nym2zfwYh3PwyWCi*;*56Br+ZtqbuVQGc1)9>xX3OaB%u zCSZ+CKo4@62|x?zLI+K629HraweC5UfahM!WM6^NuDpRD&B;cT+}GsQaFdVX26a@l z^wkSUt{A6Kq)ra|g2qBm^BM;G0`+z5wqjPYzj$n)$g7 z;D!uV%-o0@#8;z-)k}$^Mxfu2ul(*2kOxBzuLsSkXc1fy6byBug1KNhYQ59jYMAdT zhR;~y%OtM_Kj0i2K9dsh1nUSKKpZt26MqU-ZkpEc~Tk~5oW4dCW7Bz>I;vH zZ^L&{xJ<_-`{1RkhQZS^>|}`7P!QGX*jE?Dy6JRnHbpMQ>TMkY#hPyNLZ1ugs>8Zn z$1y*nA&>YIEUULtk{@*CH~W=UELBg z!NX&Oikk?7%MDj9btQ3$xNXPMo=GU-=8MWMgIeUK!!E(lfmkRicyw@c`aS%QT%dI# zB^&KA%sxqT)wY3mpnlM@0Q?Xk(MQWza~4of-34=DjB1%XC$>3LWb4hkC45~t;SVDQ z-9L^!ChNE%xG=-i;uL}NB?;?JcE-ZFHPnwrWRw7``o}J+UX0z-uE@1AUGA8^G1nw4 zg&+`oh#WcRqqn9E$*>zMAxMsYZuB45(0d&r{Q{B3ZhQdJ&FcK!+t_edB@f=UvxArz zgb(^nGt`tC`Dk4wPnek1o(XXd&}=U{X>Q@0acH z2kgqhVx3d4dE<<2?Jl01_72qns^dOn*Nju4Uhf21E5+vB zFhh5LYbh%{pj`6m;`EufW>^qg{6SgiaH2NS>Emb1p=0P#w$`9;M{2zC^VqmX)2KEX zVNx@|C_t5YMhUACjSkt5T}OgBiZ{Jyo_^^@gShlx~blbT!d-c75BG9 z5kXjQhdQX9}!V`~W{S!+a)6LzuF1BB?x%Xki^# ztU~K|ts7-bgz2P9hCyr*Z-Q6?mON7cI?wi{Pm3CG(3JiGu<2!CY$Ym~7kqrebCVbF zof^KNS2GVd1X?r$hTLr#EBk(w$yf3?`YzC>Z*pw4$69%M)YGDUE38@~E z-GeYy91wHcxI|1Sh=7&n9@P24f&J`#!mQ6uMt)0&RpPSDYpn=orFgkJ7t06p(-_4# zln-I?-1~qf<{-V!k8(p$BRy?Tg38a?oFY>n`rhIZ*`N!Y!skWJG%-3)_vIV&?^Lfv zerEUePY4#9S{;W;T2Nv-$Qfd$%BNbbNMgJ@R->W5wVcXn!L9SPTp%cW!^0M{V$~_5 zJjWg!TPP4p(dqf1Nn3jnRmqLCh8eNm%|BcMp{awL*s2C`{0t>o_%rZhPPLgRRJq&i zPkwOpp_O`rtdN&M<_FQj+AW@rR}u|8ShQn1`aJ{RL(4VN)Uh>gaqL>meP17_?(-<&I4kLs_XCf=-rqh@YsspK29IRBY&1&>pRBsO9 zmAikF=Y^Nv=F`g0y;&EgULiMn57-L4jga?!&x{(|ucryu5oX<)m=m}oo2IG!jX0Il zi{C4`hks_zO1=Tz7B3`r1B=V~@6)Sn&N?Z;Ry7dDoG*}#sm}SN#iu9fcn|}}XTWPW zmd;nC>@q?cJ|2;kKPTHFLBF}Wli;LYx&i|yof>D=^puKVUmh@)p;~^A7j!Gwk-n= 
z>e`)m%cJ5Emv}_|(*YR|XGJDFjPIY=5R~hOce=e2wfrbZuKXDfeEuU|;fI@GHV@Fk z({77fGMd{=x(z>T`rXm*m^WenHVcxlBtNw+ zdTuGRl2jl;sH=>SCvSTStbx?ZS^ULsU$CG`r6hgM6duSV#8$Q&ncoQo*|%>KutFrERV@hO7B0H77+z6izMvGG%(RKQg<=zBY*7@I* ze4nF=hVS{@!a~=TF^nFFm(yS_<3>*eHul3Tc4avyd2_&iwyM6$xw${p>9_3=$xS6- z&sE%XPaPFKEi%UqCY$JQU@L-M3-}!Nzw2MJMRcs2eGO-nX(<~=JS9J1XbMOjlJqZBK^?nfK zk3HHxI6=qLL>9nO9r`}uR#aPHHEA12v};4I>JWUdp}a6DUliF`M6rh~R0spd)iNEs z`h%HxAZMx=GArZtDQW8AoIy_NGeM^f2zV)XG}7b&>}oLONy3)HLxQ!8Ma(&d>$&S? z8gOgr(Iw-I*O8q!_z9*#G6hY)%5$aZBz;h(P)lbWW2y;GhXH3RGfI+2X? zy>2$lU}K0m*HjpUuQ?m_Un+8n!K*VY%nt@9qeRM%}L*ageVxJEXCvt^3phi6PWx|c+<88sAZhew3Yr!4P?AT~ zjSgk3;Ar~2KU6V#Wc3>;^J8MO{@Q1Pf|Q|$zU~jnb5P}@1I%o!aJk42b1<0v@S)T0 z{#wfFhFxMOub=$Si;ITnWatQX%d|=K7mhhN!(8lJ=??W-VA$48(A&ky2BPHsOIF>I z7H%RRts+3Z;PmK!$L;qb2hL_(0P@A$SVw+u0|wmSuVOWiewFkjCYBp@%U?e3vw;(9 zx)jl<+8Qp}w}H7!*a-R5v*}Ww3!PEaphNBDjfF2L8rgd@buE<7$p=G##%E9cQ#;~a zgLTP=X_P?&>Ss1dvO3zm1-o1-ZC`=&f8n$tv%f317xznY2A zzp#hwzB7P5ffIt?H)f={)Sy2%;vFSeZ|5YGr}L<~>qmYTj7*e_4Z^>C*5oJ^(MVI; zL7T5`qly~AOrvPPU>=Q#w|%2{ntpAdmNA2vu;6) z=%Z2YFurP!5;LGI8&7h&A9Cv%m5~ody@?1?zHJ&4n^xV9R`?t6D%V_A575Q+miUzS zcBhAFG1O5N{~?IGavhRD=hs6RHRWJ$##+IsJcH4JJGZiUaykDmi+boKQVe~?-<&gX!1rT_&HlldcITJ+5M)U zid$G@CnMHl*@FXWH51BJkb~7s{{|_6`pj#=8hnk>M>GcT62U z{_0;f+9>v4(iT`kacs@#DQj3+3){SWZhm;>9tJCz^&gh$OyEyV_e4cV@pR86l`nlP zZa8%bQyz&&;>q6|2!H1t8$tcLH0D{u))V)_p6fBd$@3axUl-@;V$xWLvi+VeZthz^ zY@@TmMKD(oPcN(xv;tZ_;R^U9;O?qbip@IiX#N&$B)@LDr)vCDoUzn4R zI6P4%zRgr4^$gO}>GnmW9>-!+pD3xANbb*jVMtjyK(D^jCR3_@fiGPW3oi6LBPEL> z+|)7)(B+#Q;f_+BiM)UJJvb)}6TkT=^vt2)1&|>I25jV7b0f!N$_S%l0?RNLqu%{f zB?E1V|IF`|`vT9~9qkh6Sk${9HT;v8NT|=@8j-G~WMF(z9c844QO8WWY4LtmfU(1y zh{mXPxbd=eu&E)#Q|WwIShFit_n#+B7UYfOnAf$+F(W10Ay&MW#4EuV8rx1rME{qo zr6NUN;7{M`ViX{z`2V1L3m%hYiInvxb1K|#?#7IDR6UPX5B$}a71;2l{5pd-Or~z1 zC?^LcR;B+LuGBjD@uQupk)vh}F4q_pk1n+D(m!ISb;V|>fI#^{pnW{j1gJ9t{kT(9EMLyezaMC$(Ag9 zl_C%qwuU}og?H|)9Yr#Kq)JtI6*cV?RSc>S3~l~mW3Kz?hpBRqXn4tNK8Z!s`-f*< zp`cKu?-6(YNR6u3U3_Qa18pUQfs$t{CWjWYJW1~j_Qns?jSz;g#dGQXgfR9e^8($( 
zO;>3g-49%1Oy`y4X_T^x3&YdTu?NDDB{9+(9mUc8Us%f5wv45!#^hNBN0V683I%L)V_zZb9nuP+l=pv>Q~`*8 zZI6+TsxR^kIr>oBEfj$;hjb$E&nuX%JwIO}e=FE?M~ur_*>hMks%6|fU2xM`oMrcn zQ?JptVCQkY{g1>VujC2 z`vf>y!0cMR#!yK<2qPk;Nc1>*+u1oc49Ct`-jG&B%=+9*HY!s_Sjt12zrIS9y5*AQ z`?@tF;oAE+9Z+Y#&68W!y{69{gb()cu!T?sUzwDvl<;XZo9&QQ^!%eQSpq%E>@)R7 z_9|4(WcgPW4B)9ud*PT6H*0J%Xag3HQz|y-b@F9cQh8g%6uzbJg0Q!|#m!ku#X7FZ zswI7@AJeqgzLzmO5zs|+Rsg?$JiOg%)V9)^pZCF1-*UgcWx_q;VA-|J4?dF&(7!eN zEbjjfReRoTwoWg`E3I7mPAhbg-?NN%$s-8~ljEM)N|J2Ov1L)C%^w6jU(e>uZ?)5uKIeRwE$IhoZt@QED z^j4zkO0q(F9I5!&6M>~ zfmsX>%pYXd_1)@XTOkuoowX?vTUjT9b3`T*uJi$9Kf2}=MVDR*`;2g2J&IaS@%NLT zupC4gIAUI6SW|ZrbCar`IG)bG>G+YOmRgVjWHm$u-MXIRG3FvcTMB(fB3~b3q(L>J zp*+;>D`h^Ig4tUm-^LD|@5$?~0(=L#rkIw4KGX(eD(=fVn!rwp`RUKRQzaFBtvYVf znP{h4lc0oX*L;o~$pe{J2>r9djqtbD8(4JajGCj8d^->_KTY_44wh9z=y1>tkI&L7 zwBx3f>Sn^reQmtTbl+^6Ii-8OBZ#{sgC~*=yo0M(zd(TucPqNWjbFdQY1k&v_U<%H z`>B5xh!_&8gt%+hmA;)fX#JbOLPfSpeuv6z?SEe_R#TyZPoubfo3{<7`Z z$Y+YRIzSwEvAJ3htI0sK@;=Tul{XGD*%EgJBXud_^<(RsH|tki72J7*mY-Ow{@Sy4 zo%p#=GDjosJ+8TT(r#J%XVk5n8_&i$ROAYW^W z|8e5lxsOPE_afXC#vSuq&04fd!GwK-Cl!WAvZ@3cYA-A3AEZS37c^Q!ui70ml?3!;OWE2yLkggn!Y<97&tTV z)NQBJ+lj>_buBP+ZKdu?RcA)_E$7sCYT+{b8boNBe)atAFWI+C_l~fe5L@r);(sru zr{CgA=_)jbj=LipPZZ3M)hTPuk!$sGe~JrR|Dt$xfe?c=vOQ2-ppGTafMk1sY@UXg zqEDiZ7cQ!{Yu_{uomRe_ZHmRYd#-NoVKh)aQboqc14H4~DoLvqQ~Ix-)Ue;p5Q;!Y zf~2VyugT)j#pk=XtTL6t%Qi>KdK)sk(%xFH z!&`~h%jp|d%hUIB?ntN>+18tkc?^MFo%JJIDHVsboo!2{2j!?m%`g@}Q#r8WhP-iq z&#EXeen_?Far2i8VrG1Kk_4D8+(8xk;sx}0m=aq{qz!BL^4*rT?2W$fDj^Y`m~dNC zpX-||;=JuXWVg5n_3D48l-<&-UcUzgAV;UjZc%TVxcv`pZvho$*S!y`2q;p5bc2L| zAl)s3bR#XHG$@@jB1lQMGy>8fEjdbeH^M004MR-)Z}cgj_xHT(dDr^BwbTXU%suDq zeVuE^IcIOj@e75;aIN9VAp!UuuBBVLmErl+g(uL5d7Lum9rZ9;sV|$sSmb!RWw%*9 zX9%;u5r4#WW_7snCB9~Nq8DnMbuZ(68Bm_$#!1t&YD7kk!>UQgSS)- zgtCa)D0b)-Pg%r?p2fl67VwrU+sleo4`aH*lSW7%ZDeb2G_4JGJLJJ*!Gr0vv-_0X zv-52^d-q-(ZWuJzpR#xM5!kki+bIF}*W%>Tpi;iZaszsI zLth%=A;Q10RwPH5`MPkYTEQL`-SEIJXs^Ft+mOyixQJC^q5&=~mV@N5QyuV(0==2T z)yaeKO?^O5VuHMyf?R@6~!6;O$Up~J*1bl{2Qle 
zUC=2mSmr+JhPUKq;##Sw%bPI=%itRM1QLA-3thxZ347XP8XaZbsV2aU^zngBV<>D* zU?ijC4Vtn-^O^5bSm}R9@5XpvURdL!lde`15}wyKrPZVy*UwWJ(H=jP3UWXg6q;qW zY*2|s2<6tzbRMnP9BvqE0>vFK43vPT(~q$qp1=86TY5{(zqHwp7tflTOHF6-Wjv_a z@=sSVN5E+8rz<#9oqKpLTOO<2I^q6sHGcX9={O^oG`5A)^^fA2k?nM}TDsU;)4^3M z6hF&6s&(#VVWk_uFh-KQ%L~IYN(BnQk3Ee7LAl#Pp!)PLW12o2RVwrxdn!*y(no0;?~L-UoZ7aLzq_PvuW;P^b9{jdDV>=_hig($AsHwj*6AdrwE5GwO- zG;tsFDwu_7btI0q)~?mCySSGSvVzfpi5PDynrWwqmR_5KMVoAXN6r~($^j{PBG-mI z?-M=<3+I82$z|yeEs2TEnpX=%+SNta4rbAss8YdtRb)!D3191M$(S|i8w`yf_O5@Q!wF%7L@v~06)r2)>()YcYmC+d9d`1< zN5`Ss>dLB&>C`rm^<{*MdK1v~I&pPa+oE>2{cFLkZx$Ovc}^l-cdyB0LuI71@3uv! zXGac}c|JOZt@-Nvo`q9~Uc^l8(k5_BUzrr2+{F*z@RllsE?GTTI>Yv)8E?$4&(^^> zkD*3@-Cugimq|c|Tcezu(5F*OTA%FRcw#W(20Z{WkxZua%1cL^LSFClT#gQvYQ=K0 z9n1)OkJ=^gVMdQC+~OP?(LRfq%SsMJim;+o7G7IE7@IesM9=CQrT%yjP~9utm;Z_G zIK*r4;o#^YPXPnU2yg${v+7xpe47BRy&v4uWoNty+PlsqR{iOI-t80v90*`K{rh~r zbJbt#-P{dMH?&b9Agy*uJ_yld?xzaJ z?3e`V`j^^=k}9yXf~9LCSViCG=r9yo2xibVZmUK#hLN48lHY+adhgRFup2`;MOf_4 z6oajE%9T|IzqVGo%~Buj%{q)`b!S$}vsg<%&Db)VDd@s1m2oQ1QOQ%@k_15pR(JN( zwAWe1?2Wq?VbRpHHfPTH?dv8!%e66}N$0e+`$kQRM1XR?c%ZUc0(rmYJ#L7O>un=+4NNIBFTp^1z4}^x6k-quOxjp0-G6#KBN`>>2jrAShR`LN+~5?E6)= zPN5_nrZe?87~h?K9@Y^J?~~H`*{F({6zK5ObAk8pz@u`&0*GBD#FDS#P&GKiE}S$N zpv}sc+8P>}tQoO?BAzf7&6550;@*K_EOX&t&dFAw6OmquNjvGF+7rEF`#viN9`smw zi(C)lYRjD+ugS?S%T!Id&Uq`u!ahr(&1UWPUHEe0TTAnHgx5Cb?uNIUA5vgV2;}aW zy~NE;5&6_ei6FIIz3}#4aPgO!lN3ZMn`TH^s6P2*RMRq*5*atOTa|y2g3t!_O!f0S z9nYx=--oB~GDl|9KpFsU*u{_)9#f(E(lh?}1~On#_aZ z4Ma#*7c@)2lMbcq{_=v@RcTQ!_XJipENafx@nGJi53dzrWjllnw2Ij40kXBZBn8H%39(0PhB%( zrR3s(jk=EP5qF(m>R8F!HfDZL1==PR-4-cRwBS#ZCBHXEf$JLVhe3Un9dFIt?Z$h8 z^-w>Jlv1{lO}sc`(|Dl6T%KE`;D91E?8_qk$Yvzn^M2N(k(vsAx))e1Ej=}xU21KJ zx=#Up%5o3D)^zlNoU(iQ+Bdj@mv4}#N30jf1ss^XLqbTjkXVJTwdvWoD~#pII@GX4 zS%4?CH%}0+4(Mx7o&nVrYwi8CIt7Yw9V#~Gx%7?l;F&(dO+g7&9fzT(IoOk-eq#sR z)%({fk_YpH{d&9*JAH^1dT##oww_ZaY_j@Kj@xq|uPaUMahPr1J@#}^zx61bt-Ym@ zLEcf@7L$WM`>j^17LGFsGtUKeOC;PX4Z>4B}| zyh%#rRE2%MEI%r2Xh!Jx!$^x(l$G^RXCNq+mQr39!|ztlUs<}FFpH?)^59iM~; 
z?~gw*REV(joCzzJcNBj6X~FV=*;6G zp@2Ig8cyDv1W|!m&0YlC?^tOhJN2}Z*uwIQth_c zFF-JPb3X7c` z^l>&BNf+O{K6q>U7zb`2S){$Q>xwX&=O|aPa&S+_av)E}^HS zLg@rU-GL?nG;`zFAf4&Z`6cn=@lGOWE^O#F^#LyprB%wpS^z=7x$x(<7BV!%q`AgH zUf26<8m$GCk6Lr@;RwD;XtrkDI(Nt2-1G?gZQEHl>=3@ZX!yldOZ++Og5VxflTC+O7ks?WT9HS_+jBH3E!a&xS%s>QDH-#wRkCdcMh+qa*s} zUaO%gP+XCx!E`VABh`LNDH4FOWwhZatgfQ1kWjufHt+0dj0Q6mQGrL0@z8e$@VA+e zpb*T=vUN-i?>f0>?4*&e3+;g#hAvH+HbL`-_A33o^43EcMY$fb@!?)(nVABmtxL2| zn<24sK>Z=feKU3?bF*K8tAHTG$-ARM>$Ku#Ea86?NSF^j_#o)b)#LyH(yn}3ivmGc zD-q||!Z*}I?sBQohq6*ceSM+F*@To{Ct4#FZ~0c>shpa~9(f@-^nzfYXcnJL924Yk z**eL1W8%d4Kyh+yk)Q26T^m1ED7(Ih)js?%Jnuh}b!i2NlV3!4s+LI5h4^*p z8H2AD99te|3dc=*XaH9f3RpIe8+as0SnZSOf_VoRd>60wy%R)_s2O-r4b){b$G1z% zYhmNP+R9`}Yq8`c=!xrB5?hwU@}SnOLlP#~*DdjCDJXx=BJ7QV9`AKvsHORh4RkCc z<{+IVXsg=Nx|rk;fBW#hY?>J@#5F1I6W+W2j&j@cH=;B+vgy>$;HcFGx~@?#E1{sU znMk1Ag-!MaJGBos8C$xn*WA@K?X?)vt|yBW3e2_VdPJBnD4mKj_-B02@j?nSI`*<| zeWL+3UrMt+zNDqY-^N}-8$@@!Y1w$-E+N2r48^~*F7~-AV-COV)QC}e1@eFuIt{$W zB}P9eBNLqma}4S6rg%-?Bu=Y3icdXOqk&8fq*op68lu$JQC?Lckp`lr8Aor;z7TD< za*vvoa7=dxrN~Nod(nub7M+Iy{edQ@N~eODEukwt#dA;Bxvu(8Qh7Y+=)#jmAoSEc z-DK2qI^pnq9}?gP*rC+6j$8j`Px(`_5Wkv}DTk3`A3BT?A=&&D#&*N|?XhqDuFZdn z@WZ^SIkWVzwG0of0Em+G=lsNvgDA1_oci(J`i0?ZY^c}@7Pwp+kEYl1Hli;<-qqL` zplv!C#pN0*!$o7dI3|06|DFKwOISCe4Z)A2jlr2mSge3#HzuPD-HmU<{4$Z(7ZmyF zyIQDL+{8at?Z19g!rW(v_WeQQ?ah<{25lt8ir*Vy0stMOL}+4>OZB%xKS9`=fIJ6U z&n*IZS6nZ7>?)a!0PTg|Ea7>eX{7wF)Vj+(j6#vz;0%5DMj7m!6YOr+yy)Q<()Tb} zIWG~0YR3=!Cig@fjFU>`+m=wXGBMV1@ug>;Cc&*4!2RpGYD}g3so(Ia-=3Xi(1wIb zd+KAiKInt+xe$A~q#9yZk-V#7kz^EaeE{>X?8-mOcQy^u!M3o-X!no&Vb8ZG$x;}z zsv5_v7AeR@R&>O&D-WZ9BCs3x?oN`eiX(!EgSJ|pdP_WUJ9TY*>ack85U49&7~g40 z8fZ$oVj8fLUtt53~K zo*q~;Lxk@Fef8QIN={x`H7vn)ZsG=F)~y7!OFp{cG8WMG@(p2djoW4;N3@nVR(Ew? 
zlul)t4E7NU_o~16_-Zs7e!RRccH9$pDL}9Cg-w~3r&`fpG!d*ZHiN7S;`n1!GSb!z zk30lJL(DW4Rg z0gZ28i@~+u_A@=n^)!Y7hnu@wf`t_gNoJlDInVxEF zc?aK<2)%o@OBQIF6!dzmqjw5FCFm)qE%t-zR>wwNGikV7mlbb>QR@RGN5Sx;WIUG+ zrx@UjCQZ*L)Ie8fDA`Rir8JhB?vmyIe zsUK82`QI}Y!`(3#Fqd8Cia!M~h21I|7BwhX=~cUDuE?)^4U3KcWcE_yac?ad!Dm3D z^E%8BP5A)ULOQuY_{F4hYMi0!W*Qw8T$ci+B;DGRP=zY zHrP!SNr8Z7$bpvm5#ch0#%g`zF^{AO&>qv%jBuWA*F``nuV8>?x>0#d7=>qAJg8rL z`^gsyWqnVQeh~Gf|8-@6- z6>EFKmy#KPDnr$l^2I`aN9Wld+Oj}BI(JKCk#xuE&AZw(Qa=(w@oKf^#)sj}(}T zHqhHiV`fxGSl2@Rd|-y1Wq50w6!M3(Bxt!hVEVH|tsN-UpWI8(&VkVxAfwMgDxWVt z{j!O{xL^IQ4vn0`cqh=6aYij1<&5vKMCglzfiqBkumVCDHnE(CG$GSJ|4D z)GhzKgJ*zRHQIbX3YPlA$q!A?S;2KgX$FJZ@$Ga)2YPe~v447{7TQRirSO{3zM?6n z1EBi|jt$QNQ1iCQp$eAiQcr-!sH2&ZMit_Exo+Bmvw9=+tsO4M<`Ue;W2=Z9Yk zCi&EK$#M0sFE9hoa)E*SjTORNI_`zAu$FPx4=#L>0~&3uJr9!wWi|nwVqQmg2=Kb~ zmHJdHrnp2P!G_m7L`U7o!*crt$v!-oCql?*Z9Zti!143`54?zMYB zGm!8%7b7Y!|7de5pD)ZRUo2dVkKLx5w)UxXkn654Xp9cyJOk6u1pIywHHYW1+swfl ztAl+<*Yy!&Ni(6YtG@ZEOQ;_=tBJJF_+weaKrVhQjJkB1Ce^K^Yw*8IV)TGC6%fj0VfsUFpJZ-vEvNqdvz~S7-Wze+C z>0e6V0<4#Vrz&h_jGmyD|qgA^$9(swb-UM35ts$?5$mz*Ywlz6ph_wzHC zQfsPk&btMeRLAw74OK6Lw8`rGoe!`NpGrHQlNkm39kBAl1Xto#Y--o~$Diwn{tz1* z37Y3_9mf%N{W->nXji*rw>HgNu5$Ri*q1MRw3G0Uq+K`YfUPjvtc6VJ2l%=#ix-8X z4PCFe+LtlZz)53R#U;ZZVIM*9uAw5`&_=$4cQG4}+9a!kO9TrNjlo$*2R$O%Lt0-( zdrcl1w!i#MbWS}$zTvIR^Rv<5BEhWSi(dM&S+dZ2Ur1(sk{&V)*njr1Et|a-n#X|= zDx-f^^oMacz8k4Jo2Zk(2%LHCF_+I5Etd0HHqDu!1X`-A!xt)I03&h9!IytvL7zgi z!%A4ov*>XRk zCM`#Zi;oNctk-0{Bf*?B(DOl-=fmXL#7J}?Ez(%h=@whMn$X)OxeU>^LT912H+;{A zaq!ib3U#6R4Yj?w1fjX2K@Na|m-xfqrCJKlq;2+rNMqkI0Oc%tS&E*+@MC(QGRoG) z_-Ke)w9lwo)eO|}y@!mz&ENQO`Va6=^i4>~_2qb1a`70_HP5yL8tt}cq~xsK1!6azyY{65f_=X>Qi9_a%e(SWuC7?ByR z`7+sk;a*m-;rA;aP=l;sY_8?bJCL7X_EZCKSpsLg6t^biGIcONzM!{w)(er3gCq?2};!(8Ems(z@3vRIVbZ8XLN zFl|TMFLx;dK2@N-KCMu_G%2lAlgC2AIl#U~+=$wJ)(+Hk-XQAq4&bha`q|)Xx+wvR zS!R8mTpaMcPejdnBUl!=!^*P`@&NgYoqdj%cs4h%_Q8RIq~~oIcO7Yd(~np^MH^F~ z;29^ywirtqN4w;+3!?qSUU8tm)$qgDcS$$y*1)gC6p3q-pdBB~0aN?2kn1Gc5~!q| 
zrc)HShGZ;@ir6XdiR2JO3I?}Tz~X=#Aj^7o)9M!S+Zuwe032;`wZ&0s!+(BuMG7dv zWBm;9ctA5MmzRU~+2;nIAe{hm$heg(r^M!M%tf5z#!+jh8M)15Gx9Ag#Bzc!6h~4nmvUnli*H3XVP0d;z1gpqTP(}*d?CWd+R=d0W1&TR-XX0JzV`)>pb{VUpT}?8!9o-PQEjj6R5voNK>vEas z;RE*iBYbW}8X{;whflpZq0wVlDDZHQ2-YS5a*mI15FmHe=PjC>h=Ed4~gQrNkEv7{)`W3Jo6ss1$aAohzDNQ zlm_q2yNgjcL5Mc|BzIR4yovYfND!DFyA5pSWFJ8__LtE*BS*W-OG`vSS z14<~zd*QptL0(v?u9wV``tWgX8C|_a27ObrIA;AX>36Y<2|On?sUN%g2!!_Z zs2LMP)+iHsnCY^SZ#b!|lSihG@Mn*^_Tcp-9*kIs)>qz)+_oclgem5%zdlM&BeDn% zI2vINIBY0pCmhbj#Me9zp&ri7pe;A!1db0$AoqeChbHtR4q_h0B`gCM9hRy8V<L zx-Ys|=L&gLvWy(CqO{I8WGsq`!ih%fDcL#_e1iK%(g3BMr|~v>7C=28K_WRlZ=F3m z&SwdcQ`Eej?$M|XS2UM67ctu1_5EoyLDVNVA+9RGo{)ECezDeR!-dVL%Xx42q!RML z=^ex;b{M`4XjuYJ!~31V4C_pJ$GHy|@z*N#WRe_Nk(-Q0wJ9@oGc6UY(}wmu3ySmW zR(SWQt?yWU5wGsNQ&uWS97_WYzq*SeLv1e7prg!pRbO~-3s=~Ns%x{hXS-Pc$mAgY zQUK)j(&@Sja_IUeHv+@WZ_%B@NCEADu?}*;$)NE*rsIn0^@e6SbP{VNSf$Ax9G$=A zS3cvRxDdzAeq`zC{jt|%mB-Dfc9#Uv&v4J-r77jf7V`c;nwTHe?1s9&-4?&B!rL~< zAiVANee;WBH%yJUYzv8&W-{F%Q~9+DKaA%@G(bfvDDphxbkARm<^IH6y*3E^`+@<9 zA9dc7Knbo(eTfqN1odRf)z-KZX_+w{dbq0ct*}J1cc9nZ^qn)Dar4T<&}T4EfZ&$shh91Jvga3M2#=3M_R_J7rf?9 z*Y^@@{lk@zs@)p}qSDEXFKnC=1!kE!nT7OJ_9+|crRuGGlgm#xDI7z2j*z>?zQ{r2 z&%V1RLr5e@6t=_^7cURt;*Z-QY~Ww$uQmR>qP%RzZ9*Mnx&v%q&6Bwun)TGm`gBY`+#XujC76-8)? 
z^n{)F{ca$Sf>t`NHJ)`GdYrtoXJPX^wL5MOBeGsTqwr2|JsjNTJxF3lUuJwf)mT74 z!kJ274pAgg&2B0g`NnnkhpW=(e`Q)BMPG4$B8Br?cW1bO+@tlasrgs4wPiO`)=$e- z+ap3dYWe}VgE;aLGgF4I2|MwE-L_Zyi>At(nkToJ+GmfgDwBe~ju?+W^OoOF_pCQM zy~@u|vKbQKk_CnA!Bw`V5uzK`Fur)fCci&R)hKHdZ`eBTwJoN+*jS>2@3lhrCm~28 zIo#dW>`4CXQYM48=&n12->Q@+9f504w^mgbJHbda2_(X?5FUHlDY16}1;33H|H{Dn z)KVW|>1-f0@4Kq)gE&q+T{n^VKG9#~aU7pFG-6mB-20&8LrFEZW78QyGyH%DzG;GL znQIW(Kert;Ggl0u-{E|tKgYxsUf}?xVc}#^-I=eWJ6J2?lxUrGY+{e-kLlDlcQ!zu z_lEKL9QNy;eh;GAtA!wNQ1DfWmy74m8sQ^^W>E~N@GLcPHIQ10+?foF7KwyGzAuIN zd-&)Ub;;TLIDsV29#W`e2-ByB98|Mi{WNnGX+ez1#r zUICp+upWnra1ad`ej5y5+%1PK)_!zV#IY%g-$zjul=bhJ6;Hhv*0C!dBz#{0i0FHU z1dgiD9(`}c zoBzBEMul4?mYFo%<-d`6qE8VCzdBnx{(V$+z7+;VJr>2YrjsWe9*fHRX%i=!Z9o0 z$7;2^A`;&mytnbpazT!Fc{f~~Rw46ysSz=6EpZ(B$lt_P%RDS1?-vXB~>|YhwtAwY5enA=)7>INCjwekHxlI2y>OX$v{!jdX;DVsj@0jhb|DdiM*l^-zv2Lk;0gj*F&)#Dmj8PA zBB%c>N!MSH;0|8}Bbgnb&HbGLTz{bD*B`q6Y#seO5yH<7{|g#^B|@sZH_id)B=>(> zS}Xl2|N0AvX3SeSjOyF|?F{C={PR2Wem?omfiH7@a^I!?X9fIf&;J)y5nS^{Az*^88+ZQw)2Rqg&Q>%2 z{}RnY+s%J1noBcVRMx!kZ_Vru?d|`~j$cB#Gm@rL0=bX(?@b{<3X>+_f6bZy1LIu6 zcsN_`-@_OQkn6(#;)z^}NCI`H`oDRbKBo`<6@FUqB;Z4~fPDres=oivzW85Un8{M3 zO(NYgVz5ADoymXl^r)Zz#pDf^ZDI-2 zhanN88OX7H=s&Sh*q>#80y047iQh6%@Y@pby~v1b3Mj;4q=cK{?h-TnbW}C@RDuPo zhWsw|IWXf8AlQ5iTtJsuSFhK9IRZr?(U51Vi_jp`F|`B;L?UXr4CXp}s%ET~zP1ug zXYVH(DKdL!aRNx%&54fm7lC^YdWKF^6Txu*gL%QhzTfCk2^!ULuby8lT$5VdU@ofs z{fuB&k|z_5h|O(td4>qY&zXa8h3zT2-MGg%=ZPc|A*sL~>#Fp;@kThh_pn3HYvJTjrd!qld2L6Ca6<`N%*NMd|ws(83_cU<{}g#BitAw zwUdjBk@IbUo*3qr=jRMO4#QW=pE<|VbF`AC0C};!FBc8_fae-=0Ba|K?VT8Kn%jDZ z|4b=}yK3kG*))UMqWzN;&hKvrcSA+bS5nIMzF&C%*r}6p_1W!oT!2APHQ9%Zt4C?G z0v|15j==MC7X$*O&m_;J?5UR{Yy9&r(nP!|A->O3E>j6MsC3jeZmI-Wf%FQH8p6{C z!S`u^kf+MOiDoYYb!JH^TtyPOYa^wo8NDF%JE+b)fi%#yvZjN(4u+60FyhuJ@`uXtrjI74$Q_Jj|2^ZHkqlh-)29NI?4%aHPoA z#i_k$eYVu?Nqjelver(TX20ELgg(>+6a-%$-2Vy$rl@87?d9~<7q9{4;dJZO)n>$r zt+>!M{ z_>E+*o^A6O0Ej3vpPiVwaH0uRmth4Lcx-g@t0UcZtbAyDGa4Nx-PCJ$F@al9DkuJU z#+|XMcKj@nRS@vOzK+kywEgh&G(pcB_9;C$3VF;d_iYZ{AV>rFaT%D1KD`#|^>`_z 
zrhhcIkRAuewa#O)UpGqM`O3QM#j$+RY-DPO>V8K<&^{1ZEZ~7paR((}EQPEi;F`uC z$AN&*9Xp|cxjzeg>kc~%;E{tlKho=wm{ypzw z7gjCgUGobmcJSpHuv^j0%LU1z3b6-DjK;cZ@R;}48xa$4Kl!Y)Q_oHZjF@>bw*ptg zb0Sobr{!pZzURR&0^UBp7Jo6+^ya(TyYS=o3^SR1W3dU6U0LZfISqA3A4OuUn@|i@ymP)e*2O&d;!%IqmWA1=3Er?KeG;b!nfR&e)B8 z0P}+ijG+5zO{lL2Z0GF1jcu|n7UC{}q^fk&g3n6jQg;1+@}&Pv*~r2}(Q;dcQbTcn zffG??AS+&z(v(&-K6)oE_kH2r+trNwX?HraKRm4G4EN0SxM%cXlJHNLbwOcX&am6N zi)D^6RZ=5MMaol8ctkN{#4u5w)tkh%WRRiel*%IBLDtopQiHs3pXLGuK0$H|m$Dg< z1VUiVuqNPdstP-6e4&Dx+4HJrTvOcscHIfjhbT30>@(^|M57ms>I@{qIW z+*NCgKfe!Bk<4$Q1IexQ6n9o~v>(J3e{L>iE^Xc#uFk}2vpGGb%=hRkIT%N~dOW{G zrBf12sHg`W)(JgNcImS5a;cfH7PT_$de`J=f~P5Xhxs0s%}M1;N=MXHEZdU@%%1l4 zbTPKLKy55VqPvqmA${kA`!gb#dz)o=TvL<5^&3iKV+pjiEA96((lGI;MLqVueenl= zs0=#VfT+8CTRl-Nn)t4S1uCmAO|zVCkUmo*2@Q8K6&3_kL%E&*QQ?ksK0QAb!8W1nD+du zi7^8#QJ@gGzUaEbxe9!HBr}?dal(WPm(awt@qnFF55YsH(ux|l5s&xuItF?MWyD4H za?g)v%@3kxZp&pA)@XoOm@w~aTR*8;GPQEd<~JmVCa0Vkm}5qlf+L4jj9 zM6~pA&4cN-jmZow`0B94Mk-A@c8{oOOV@+cCOIemRp&Dr|it6;|x8yGbY;F5lE5k5;u0DY(T8mFfxN%)3Kgo zJp8G>A8Isri_&n$*W+`Hn7)@YJDT+dEM|nLl>118 z2h>MFa9t$cveMqL>jTh{_3^oYH?xaItewPpA`+7SJkGZaN%1=vBI~dY3`VntUi$K& zl>Voa_NeSCpzI;rqUvjje zw^laf-K%p3sZ42%LIfLdaowtF`HReC=9#XcNWKC_p|9IOzeInvQKqZ?Achq?mQ98x z0~(Ifr7|=aTQaWoeJb#kwH%Z}MD=Ml2VN_vHnFZ!Ekc-uVaMsaU$=I})Q-}3$Da;2 z;o=73p?MiuvS|8ci%e^Ak-sb-(dU@=s)M4mFCv)8)=X4%Q3`KMJmeKcTBIz-K$(w$ zk?AYW%3&UBK51~@N@JM?$yR&2xKaHfa^|ZC0;?0v#g8yHQp=mW&w4v~a0||RO(xz! zQZE-%`W<#|FB4(B^B% zPLf610U0;%x2$COCo<0oi%h&(_cG*A`OY8@?kHSwYu6yskxg8yC_jzOX6lRdsPR3b z>XQBuK1^5h*~@9kmzaNLyW?FdntvvT2@5=oiP*ZZ4M%CHPGB! 
z`4W-B-xt8+?Wi6M%_p$0sVU%O+XhK<`$1wr-W%z<++M-hyn^<6pG)>{hNaZ&x}H#u z4L-D>jrf*5Bi^)A#Kox+MdUTlKKGtvDsE34Zvk zjSWniFFK-Oef5z^%AM7DBR1N?6qRH3j%r%9CaqWdx+Gn_qcvdR9ge?qoahRxd|pqH zx9l6>Jk+d^`CQ0hVkAwJ?!>zTm$l!R2tAZxWxOfJ^p?{*zw-}qDMwAkZcqDkazR)d z#zU8qcIC!rI$)*jZO2>3zXSR{@kA2<{g`o_AN*2L z9}<=_`HO(Fk>5zkSem|5Snx*KQObF>-d^ITY1!PR9->84hE03^W&tiagiC~08muA~ z->AWehS*%+h%{r1p$YC`EQ;}S;d^c&M=~Csn50vK2;DsCml8?&-GG>zMXAK=!jb{7 zmD+aH6Os^8NYIvTaDwd#j3O@Iw8F#^Az$w6Y7k{NI8xtyZ<)!Lvxq|L?Ko1_pG|U` zbwgjfOEPY33G9Oo$fRQ*S{3EjMa-1)1TmxOhY~%POFL|h^>f+lwt^?qRqlVp0Oqql zTfR#qZ;7Iob!%m|0>}V)LH7*`(P~EtUlGiY_8TErwK@_YR0BUdF0{ye9hewB@1c}^ zACI~oqhCU=2^ecW5$()^K|%S*&{i_CL`^U%P`4?7HKA&E3cb6KEq z06vx4{Lb&7zLacJ#ZXC!f-D%mT&DWIcOT&V&ac}z_$RS(Z&f03%)c_3ZUx-CwN zf{KdX{GSrT`MA~ixLquuQQJ?x?iB~3B_Sa>-`@#YS^i-$ zm9R$&u*jimjA-NTgHNDKW@d|+OWKHKN2R}rn5CvHDiJGcTTe&g!84`sX3i2C4FI*; z0BR3v-9I}4Ps-xr6Ok~f-g?6`Oco|*gSv~;L6mGBRdb7n+gA0-T= z`=7rH#*Q0@mpr=IS}E)7lAPHAo=AEy8?kWg&f6_H>g)Eq8M-uReZ(rXW$O&QEYU!l zc5>I9%A^Z50Jv2%0+ax9-(Zl*>z?3DM=;_HhBpzSub~vl)6@Lj_Fqslee{BDic&SU0OmiyLuvn)qRFO55++^e< z-D41eu9Or8J@WM^0N`Cr?K7M<&n|5Sv}JzZ?*H%>m>ZTqQ9^C-F)H*`SbfXF2)k^N z*ObR80Vj?I%R9yr>-T7Ss-Si}P^qS)I@Klrv8?In z?C-dB>*WeR5xOr2h&Q!SDC>1`0wIgl1#y`*-G5N5!71)pKMlaB8!1VMtm%_TY*&Nw zrvwyl;gTjcQGPDZE-^cD+Ni@FRlKj4W)$>0%Pu~C<?CL%^1*M#k4F_1EHZiH_#!IGH64p^ zsyf)JNCn1B)DIhe7qZIc&G3`D5@egO%K^|1Z?AyI0NyED`vQq*KwS^?oi-A*M;2QT6!;yt!XDn;BpM0Ib>N zDE0v0+|k>g_g=9c;a=SVoP$6&i&_()*1)3wIh&&v0~#O1Z>y&Acd=UgPxI_ZKBwqY zOy$v1yny0HR@AuOBEL1`(ixUcPLH+JKiB)Ka-E`DMn%Wro3yCrL}zEkD%;ZABaV1u^+Lh!g_5Dk~a99GK){KB%~ zcfK-+u4H8K`V}T_wb4W21|Z`75s~&d87s+Q}uaU%LWMCu5P@?b`lt`>ofHG>QIl^}1zbn$C9aPwQf;eX=-NLNx@ zM#M>&!}^D7KosuJdaLo%^8KZ%j7h6r|De+7(U-TMDSHF%NX1Gb@Qc)??OHF{p1u-} zZeI_Wt`A~>Yw(pk>|E0yq!4l8v_Lv2i0R?-2(5{}GTemn{}k=8fM# zlkM96lzD{EyW^Jr@M6fD{jjhe~q8**Fk8cB0HTTg9ZInP(Q z&s4lJSVSn~sm6z;OhDG{N=7}D9%A%t3T6+BL;O=$oVcSB%KNqj-@Y%6)&Dqs`}1809BA-_&x7-2=y15Y3PA8&MlV^8F`n;RDW2slbqyU5&NQjr|U>-)zfu zbV;6@pfO=;(+OgG6hvP-e>Us8`%(yP-hvoEi;`X1vdx$uz5q@#c)9-qhzUj#bfYF+ 
z^?r*91i#Eoh*_~K7#NQN&b=-W;N2Dxrd$BOEuJTxLMo0Jc7J4Yvfdx zQHsmAiEQ6VPU<*z0&nU5GC@#cqA!z|-gmyde?6=4!Wx9iA94jp-n93=GG3Rb@hOff zHhQ5NzT+Npi^T?|&^;}0k(nS-Ra#1wTw{g|sG&deYt{i)Igp_MW}R6q3zW?+KI7c- zK89WjfR2<>ez=qAgLnju#>?=HWrx`6O+K3wSUuY@U{Dw zaf-S6%y~Uv%iep^P>3AY%zQ3+Y2SfUa*P)Z%>|TP;Xu!gU!s0|@}|#v05Fx^-SE;| zEbBtq@z}g`mZ+wz7n81bh6^_d%lWXwb3caoMy7rGsmArJd>H=-kL~wGmKIT8u7@D> zCq`bxj13&tRpU4;d$5}^6$a%)?m_`(-lXrQJ|5eVY5CdPTUNmXJ9i0lywjx}X330U zsf>4a>izm!u3oQC{$`D)c|)v*;lnpxi>AtIkf$h_;lry!kuxgvi8r_v0>}muytSI- z>Lj!ea7t9^i>#2(YhK)bV+1v`G$y$w^O`n6!IY)SFS?uv+aB?BhX8Z3%ZGjb1e<@E zTI$gv;L^e8U)*e`!UQx1)-v)n6Adkq2W2_}zt)pi!gp+0Y(3TAq6x-F2dn`=>Z4+T z3p3Nm5bq_WL}4~*FNq;w_m-+-O|+U27QmTwR-(*uFuv)j=pmgIa!36ouBSPIF#51K z;@nrsO7yV%-NlXi!Af6eA8ggi%&4K(pRK3E|*|EGk z?NdDiSCT9txKj5_-dCv#G*3&=cRAj-y?Aw2E|kUuKnB4R5&aT*u#a|7h&;dp;Eb}D zbPApjqG7vel`O@}-@0!PEwAY38P)tG_UZ(SDCN9%0(=9D%4#dKpGFolSximIWa7~? zA|X#9j>!?VVP4_f{_b`gSRs>xm9Gqgj^aUWVY%WP?1!Qli>nLk1o6HuS#eJ_o{YVVDQE;h+GGw}E*Rn7rn^3g&H-_B=}fgN@w z6lm;tw)KrW)JxC`lxJ)Bh|WmvYWOY~0wnf;t83(BYp!>Wz77GwyR@%8iA^!!M3%Hy zObCp1NnZT$(=lSmF8v&dM^Z4paHuGCsfs)X>sE9AaC5LwzYs9?LG3#yLtfN!QiF_L zZfWhdyFR&qUCK{6sYK1Ijd>H^?1j?;1?nX&FdJOK=6Rvxv-{t;+#l4!@)nv!zW>u9YRGQ^6S#-l_4K1sU_*ywySbGG1vak$Gjlby{OU3M< z?b)0oP8|miCKMXm$=hi*slB9)!eL|3ZJBWwF&l|^{|VN;b{(a_veERWZ^4xP=z)Lw zXDs?M3-)`$y_Cxv_MUKYwXuG5r&#BD{Iy!j6Ep7{!4Is>(9p6Ni_`e&d@d-aMx2tL z6=Kd!zxI zcc!t=Q!x(v#}=z4yciO5X1Co3h*+$IpQz|D#Aeh@(IBxfQD1Cw9q)k5hf=yoCi)Y; zQW;qLg^p>9i#>J6pX2lwLgnJHvl)Ss-k!dv8{p33OG>nt%7VMdNC!Z5r&QF>6?Qid2W>u$b`SdnhvNLN8NQ zsC7uSpA-E=$J%4UF5=~wRuiwby&)H z^!NKnKr0jdg<#S-15CqQnwTm5PoX*Z&vpJdzTv#2&C@4gho~@`qOuA>coI-cV$^OC zp)&>|my+b)RmEI)6#@BV6x{MP@EA8eQLy5~%W>5RS@p$C53$zF%mPJ{;)G6dr>PKr z{I*XDG2=D&SEsMiOs9pVH!jR~7xRtZd@)8G4aI-h5#Re;wbjX1_#`KWfAO|P_3p_; zl#+<)_tgb72Mb}}8HLxpNdj<4d#?XKoLmIJ-EUOh#y~|3G?MvM=K-!;kGK6B*?_Pd zRyjbW3O2p2$7J$P?cB3&^}+EZ6EkolqmzfNRNgyawCPP>yh|ZAul5T5kIq z#4+<`DrWLRusAF+kiB2+*TbnsG0@A0a_=Az!Qq;kNblMo>BIs~i0aTte% 
z47J(RK(TxwLumxjR6yyMncb(&hkhcUZc|vRtSJU)NR#hlIuBj^0?(yp+(LO_%{dE7 z-G|X>drGJUS1a;ewm;oIk=F$2|4|2Y0ldue%I637KX_}&@^PFWE5NAZ{OTA2>zA49 z3M-Eg6*`j=L7@rBkmb=@6dJ&DadN2PA^L1A{zrMO{t75A*5{YabM3#Kj)-IK<{nTm zGueiOECunu@D0wp2`=&kqph&AoO^{a(f{u4SE>2f2%o}Z!9 z_F}vDUU|!&WoDwx8*g&715m$!k6_orb|E4v1IzfkCyCOa*`EGfxB`I5qUZn4^7czy z9tKe8*-~(SuGH^yKt1F<7AW!jx_g$pR-o{FNl8`VmViwC#Wd{u0iN2b4*(^U)MtZ{ zVwoh(?aikl`z1KhVweNBcgc|2LiG69{9rrIOWADSAKC_`u{~$1v@as)^BNe|xecEu zj7Dg%y=NJioSEI%rbh3*zI$U|xx5e~IOuFMND|>Au$ceuKwx6LENM~gQ6=RRcl+yP zm64DZI%LA2XAppD{nlyPCzSueufDMTP&;DkF;r~=gBHfZ20y2BU}yqH0s&-0Y$hen zc%%%^CKMGw6Wv71O3`9kW|htg4`fjgv{`9%7ug zvAEkQgySl8`1g6F=V2api)ke|6i0Kj4#H5wEy*9-i6ftfnI=AsPSHoCl`tVoC$Zz| zOHQ+5IgIy-@6=q%>Z&b!5w1u>UAd!%Ls(CjJpFiZg~QMB9i;q>t`D_DwztAs(*7ub z2>9Mg4QO`>cHL)YW{vnFZuL~b$A^~p{KBrcA|>sly-E4#_-n4WTDud6J+W0y`Ie&` zKbS`a61p!|zUl(65`>L_a7w}5bht=9s?&+)EjqPMNG&XHhQ^tNUE|cAluE}jU0;~} z8n|&pATHE%Odu5r{bPOtF={w-rdfI9_iPFMS&l8*+gtbw-whPno>gHueZ_Wb=7W%2 zTSwAtQZj%6ITv{OgP#EOuO0j(D{z-3Dt^`(HE=kM-UTL8C0hHTb_`i%!?U}3F=1bm zaKaa-r+>?^uSu(v=q&*Kb8MJnUuaZsg$A4PW#+hlt7cRsHET(I{YofP3P#dGq0~P0 z$VxJ8L-85=CmL{tkKsnX`-9fq&A;-nedkXOyRkvv#(|;&k@%QJo};W6$zT=el-j>K z^Ay!UW&Du&jnbn(8=4!8MRGhEvah2rn4vwd0{eCBu=JA25Rhem8$qNNEF^9)T9I!( z*0GC(bM%Q90?4fOYHw)v>HqpnWMfi;I~*Q|34)<xge}ep7U+t@vfjocB>rJTJbDj%@0WXk5^^u1#iqa52 z*4TU}KlTr>J^XxO)&%ePg%k%RSVm6`{oRWGho;{Syj%d(iewuFpwl7UZOHhVGfb3v z77s!CMddV>-!pr}*tWcp6jwg_YqbhhH2pxoj#TKXY%wyEc1;#(602OfhU~8 z3ckVI!Dz_gg3%4#5Uk(+flm0IHuXh5Xf`-e)(78Pq$ zx7hkft00I<`O#U}W0Nq}(QGFxVc|O(zPxm^3u$tEYG@Z2lt>AlL-rfiLHDmP(;tO7 zq}O#Bw@eCORQ`Jk?{O23fP^SGI5dldd8*g`2an43=cmsC%pqSg=#J*omj~7GMeF64aLftGKxtFw zLE^EH|4V|MjzR^VhVrAuK5hM)K%)`bO|-`y^KbDtA2t175W2dm$G=7o zeVH(f-Ql0!n*Ssq-Wf@mjCmguq@6WJ$68Zh!nsUNdA&Fnwp@EL$Hxt6LJ}nfnC-yC zDq=g9aDEI{r#Ok2!|?*FsD_cSD zGED{Mg_gh)gx4egMd%Y>`(|#=OLl?O zHej>iY=yEOaA$Mc1EK#Y?x?T~84KcmbmSr{{dg4{j0H2;M%FWJ7Zw&@_EXu_{UHkC zg#Rq>No=>7P-Pqu1G@=j^Fh49%nx-^ljxYJZe;fB9{v2&dS_JX)Slc5c5+3cQ->S- 
zW5^II{A^0?Rr$-cAru`GSxgx#duGb!F0#iG*AUNUSdHq^W66hm#c1QI`4I}o@tWx5 zG_;8Hj|m(MjEBX>0yMb$;NekL0&||E-WkGY6$E%?(S$K^lBZ99W5~Cp zN9d}p3J-q}oaO*J+G|L6b>A!2#?yEKwVz92mIY3z`tI+k!f_eD{w~4{5bW;cK z_%G?f&y1AS&l3H_dx8zE_r!6}bq=ln>iH!qqak3np1S}K{r`E7b+)OY_@UB0c^BfA z>YMMaj0BbU`u}C@0!a@Qfpo8+a*7gCqQEvUo0X5-!uOOvNe>YFziPf6Xm!Xv zIIWEs7)nEpq<~@QOJEJf%e6WMO#}6q^4)`|l&21T#tvu6z|pA>Jj%gxegysOTvT(8 zRb%=c^EQFurz`q8?xpzZ@y5z8n?tU;lF}FI*;c10(sA}(lqy-M)-+97!?nedblLD(0S^Pm&&=2+@KkW-Gy~E}E!v9gW7{v8K=bNeINnz66=Errh zlPj7tCDzbyT4M_7C;p$z>~d|iIN$~L>z$*9Isx}mtTh&&FCYNnlWB25LxjM3Dn3_{iSw}S zo>W#5yquh;q2vlvD8^uJK=_g>AQ48Nc~$UtoNY1fkp|E?PF8!QaH*6>jP* zEqfvO+A$dQ8}FjV)*|Obe-d{&=}MffV?)OyLIV|XQ;i^giSByu!ysAlG+7IY<_Aq(ufAN3iyHlcDdxF;ZXnXVI0x<;w zw^+bmw|-P-OMf1ViP6H&Mv9vMhWN2+M6LT2Vl3*SDKP1QFL6sXvNwmGZYgDLA?odwze|)*De!6!S?1X^&7k<^_zVMqx8;s-c=n`YgVm}m< zIcn%TLw4W0=!x1VH}zM*$-rb-jN6Sca*I9ZMBiIM(54>-ja0w-HVXPJ(2S)};4p|#x)_u*q zkMK_Ty!;!@g4pjyaXE}a8xoR1f<&FLMVGWzPtxdy?M=dl-rh|mY{6fYq2IY(z z33+r%yZ8qBXUJ}2(R&;y-%>5Abk(!J{FIZDiowYCm9QW|d&H@y=(+aqsQYm~vigz@ z!%lK+T(dY6%7~u!%o68+{Kcyfsy1e!7z%^JFVm}#4nFQA%PS#En>fqConr>uJ?*Fc z!l)R@rL8MWCE0({U?1NI8~B7!K|Da_L7ug|@ZkTA*Dw^#2_frpgT-i|wwMB^i`xI< z`na1d$pei3%rU58aUvP=c8iP-YJtnEldt@QK+C^{;gF51P=IzpXFcV=0%b27tMnayYWotxEHM3i+?6B!B@| zpY#7$C13pPH4c2c7t`OzDn|zT9|2K9;kn+S^|EUK5*BMzUSg%+y4I@jMC9cXW&dow zdzK8-!=hqs(0s6ho?6;i5lF5C`L9dKtMDjtvQo=-__Fd$Huh9`?#o&owDoT-K*B;Q zX2WKh8{ccEpX%uro}THEp;VPYB}ntI!;?tpj97yYmp@mUZKI9}ib$^i&PcE|w<3mq z%odj0dKo-B=JX)srC0-2Kdtdfn%cryHu5qH6E;{*_aKZ&&@Krx$y`kKUOcfO0;P99 z^IW{4g8vR?G38IGx$lX!7V_=hTNmU>qm%#ksSp`0b`0+7pvZgR{(13)UA2?7l}*Od z!0cEa6?ClQc8<)D!htuy0DH(pXgxJ%2xt%AuT%uX5fi;Wr;!YdPRd0$#w>nknjq%2 zqprEQ_Xk#IHLO8MWb@genIB>k{R>aWt zH8}F0P5TLuj{wdKM`{i=&7A*;62Dt5cBor-mD3;cX8`~{)@$Bs`Qq@;zN~Z%mbO{? 
zLW15H{wlD4RdC?-QO<9sBF6B$L>(N=VBqNH%3OF&4x8QlU zL#D2|1jnlzj31vz&dw8QxSIt&>8Nd7vfZV%ded#>3}T?7hNmHAZ$9K`Pf4V@IhR;< zJiU`jU-d#qAmS0Xu>A??H|%}7opnXv{v0gzG^11>j8!)Hn3Bd49^cG_th`PR4F3kN z<4823+?{)%8{U(|**I+&ON-m#I4YfWxP8Bhwq6!Fv!%nFskQ4a8FftqCB+kN4yB^I zOmKRKf!DDx$5828ql-Kd0LqXkfw*Y@>RsL*JG+6_z)3;sy~u+eQ)7Ni3By${1QGBm4?V=X3>_-qh*YlHQh+}SU)5gSbk%lj zly)2Pe+O0;IMUw^JbPjT7CgwTYy#PmGd3)x3@3R6pilVjU(oqP@nLBf+1ZUkuO?Je zc#%3`p+JO4apaAz88#WpBe0aKRc&+Jwn~C%lx~5{y1feRAEtmhoCmX^&U5p@*H`<> zwai^dYk&z^+^&GiXFI=IsN#F$B_U6}5sV_zjjVxrzA2VBm+AaPGI4ie>Q`?*!}fs) z>rJ=ll~)sL9?T9~0f9N@kQ1@!lm-fQdAPX9tzp&t1XQc6#N}$N=1M4G^OX8Fhwhr) zb|Xbe?bA&lGny7(J6-7t!zi$|u#MR8+0r8!2#I_~96Cf8o^WVSz%hwTR}Jp)LBO@^rbk^{rL?LiT&!@FStxY4E^h`CQhsOSwni z&BLH*ap%3Zl&qA5sqIftXlq~edJme~yWCO%WSl=?P)PNxJq?E%=AxI0I6_%&!PLFh z=28jZag|tIiBzwe06|2r-tg2PWdA}D?L~l_(h8`=T(~VKNkqa`mG=HYKq{j48uy%z zWUQ-iX2Sz{GcN^M3b9hg-~H?Os2J9U_az65ZXD%z7gXg-UgrqX zp9&^Vu%-6DnpQ?|lRw5;R{`DzfDsl=t4b3v0

|M*`rAtr5>|G^a~xfaGXx?m;_g6>k2O+ET7yyWdjU3&J^% zmPdQWymrt@Obg8v30E>c6_0x%@QwMpDgxXLLYZpvkhrBeXF)i?ge?!o+R#uU`qJ56 zAScY>xDogl!}H+;IK6!eh9A-)mroSU$y>s_)DU=DUh_boxz~b|Z-jY4=~^FjnE=~w z!mOy=J7iSdt!pf%dkl@7r=aHLL~_9K2GES;sDo zmCr`*^AoH~NqaQ}YWj)RE_|ahOe^+2@NQo_L3K{K6;GttOKO2GV_RD?PgmNdc&&qT zv&d1DaPT#{mhY|8YMp-}#^zuAh>_ny;vV`axyX@ehc)oO(wa|BA+|rmNs_l%SQUCL zRH$*B1aU!~J;P(vX{CnCk>@wS`@{Nn1!q5RD8X3;Kal>+oLn=rCyxZ#L3a=SZVZ@% z#fPuq z*N^m+Qo>B3Y)6@Sr{|Ur@!6B>B78}Gh65fAB1>Ar5ij!RLLd7;i;SUPd66ObTQCBJ z3NAbL${?5rmejCMj2oi@+o?sQjYnzmftB%rcW)Rb|f~MTtX{uDe41s`6I#SH5TNunl^FPyxXjvFwdcNwh zm<04tL=3m%%fb%t-o(j!<0)+wJuNQ|j0r ztOU{;Yff7114#xiJ8+dw_#s^{0l13n95=FRQ|kMV-|DX~)kQwzE57#d85*8JtBaA% zxwTyCrem zTV9beb^_u1#BpvQUYYZ+8%Nt|x*8O^`L7$N>+7L0qg5fyrKx!UtaT8S8|nP0SXrP= z`}fNKj~b{uCX|KmLj&CwzhJrFg%Qpjj#u9p4705Rffj>!7l<{g>IMNsW+3qi60cd` zQW8fs_?K!!Wb73EM3n1iNZa5K%$q-lZZb@w=-R`~%CI@+QFvja?7QW1H&7+7xaVs! zOr~mkV@9m1V=AlDJe6Jau&;4K&C6eE&lISm?2A0*tvDCNz6+Wz%cwvpy@YhnNv7&L zfOcPFE0)^xt6uu_IOjtELqZ2)y5#IS5Cn3P(hv)irOIj@Xab;yQ^=cD;FZsr5t$w- zxu}U@DT(k2WvR$_U`jpn#GS`K?t7Y8&?L9}`zs-DNT?slI}@)>6_QtEXCQF$c!hv$ zk#y!sWpMrBMnG#G#T~fbY5T zglZ-E$o=n-tP+RSFkNegA&eiZdj$L*uK(&OfFcxXj$~{od?dzjU3>dJO#`I zptIJBTxGM8X~!qKdtL_i>+4Dfw@I;6Vq2a#5W-5DtoOQLvKmCN;w6*xe({wV{^BaE z`E}u94qW!0lT1+&|GWqc(l)&l@tdOTlVp8a4$PBqUZTNH(-3*4YOO&H846=c-f6h@ z!2O_kA-hB;{0Ehv(SMddlOXPayfzR;-PnFL?claPt&S;ps!kJJ z*4IWXSvTBW>VPSKKpST#(wb+sRCaCl5*KipY1ehQt&H9D#3XA^sB{cn`-Mq~nWG{9 zIW<^TpNuQGzWx1?{Rjhk@JiG=wCv4^#VsoV*7m;~#)q-)v zPs^WtOupz%6Rg!wQ>pqot^@@V!nByq`O_FHzO^ADuZyW4HC}&L;4VIS_cj&_>Iy(< zE0BRGgZ3(v2xzp>#r2c0W>S({z(C6 z!-uB?51Zu`ri_c1+~v3Kfea7!vm^B-o_99;x)+iSZ9q=RiTrx9|K`S_k}#uBkuq&Y zR61%5Ty*;fnT`i9hYl!Y9z~0@Fp6S^MWw zd+%?{z5Sj!I#1oy#bf>2*xq^O?%_mqPXRXWdUkz4Lz8>lu(s2?mywkE7K(HIY&(za z7&CY*thPr`SLRn4rlA3w3ZJ+2ps2+-qi75ZJ^U&3k77}e6&q|P*g7f(ktk~z%&-3a zj*~C$vx%Qf9XBoo;?LdU+rK-JemHFq25(3v*|D6H!FWozE&1;A1V_JKOrzsvArn~| z;>XpPk?2I;XW%V&tD7txbOlR)l7U#ytI_XQoNgn&UN};#@>2aeMO?}Zg(dWYHE5jv 
zND45f1)lHYUS3}TQo~MuvJzQ`py>&r(ZX8`NKdr9>0PgjV+!zE8^Xz zMLHbI09HD#>hSRQz6LBuG)6JpZ)(U=Vo@Czi zC*sVLy~AE)5ED8=gDY_)M`~K@kh+f_@IiI9-MH#cWRxXu)b826DjueWWLh$yF-Ie< z1*{WaK^3=mrgC6$;U$3E-on1W!JY{EfaTLkQIWW&LjoOdZ#1T5aEvfcFJBIv&SbyE($nRhVbd*rwFfUX$mUGft6k z=U@j$atz<1AeQ@pPEEj1_w+6E64)5VTN{$WEjzV@bRqd^M4QF3#E;B6u)rLY#iTS* zRhlX~VEw3;nKf>Wz~K=Iu7@Ka3{06$n(JA~U+4Kkrm5Jc788WcAWCwz6S|AC69wL{O; zgOH>4<(Ja6a2$tse+*dXQ(elk%gD2efKt`wg_|p!_25#9{wx$3lgDv9i8+B*jPm!2 z2L4Q{X95G($jjkJwQgOzOWj3^6%p}HjEvj)2UPyQa^_2a@;MG4%aC=uHf21A{Pfn~ z2G&2}P`aa028YN6{eE+OG;BOVUW|dW_5&Yx z@3f{w1fp_LZ^{^~j)k=@_~Y(l&-GePy!P)lCGR(1=qrS6?rgAebOtvj(=ccl=!p!y zFq=Sto%bd5%twVa^_(M7Z>S$2n!hE-c zg!tHVnH~Xxn5Xf+?D9S%Wju>s`Kjp<3u}@+5u<03sxqtzuY`>PSKt0S};U~=~&f@;;XNloqm5~Z(-Jy7pI z1oKUN|JO_V>7g?7R2BAr^w~$}t`LI}xKw}+9wvzu^gMXA05+chEQ}apn?M`Dr5pLM zRCPZVb=wMgNOC zla!u`lyd$I$-vL=dTR_#g(ZlOZz!>~45<z!oyZ=EQgJY#?Q-ehk{&)wi!tPOJP1nK1 z^}J17T~ogpz6{bp2u5R*OPE;c`HhrI(GVN$VkFwSKql@+=9E4E&dT_>E~cP>kV`ZC z*?>H8_7EB(?~Ef;Le=y`S|o~GW#1@dBH>a2?y;W@1`!7tEiLX0XQsr~ z5b5v)hIpa!MqFZJK8T@aSL6WFo^je2udFnMyDxIVfp>6qMdMwC8DOMHs%%C|MrPmjIU>G1gKY^~zMJ|0Ky(n~ty`$SH>-bR8U^yBa2kMTAk7YQ;Vi6T#u%NZT@JWb!kx{)=40Fv^A`pgp1Gjo=EG?)Zu~MeyOnBc%Jw)N{ z5nE0u+Q}j8swI^FPPvCX@$CaQRaVNcKF!w#wZ;#PWQD_vP2D5cYq zZ+m8B=GXb4ct4*sm#5 zI<3dF!MiZtdT>V|iG}ZfHq|=t|!;KAZbmRshw9_-^?ANL01ghU}iQ!fvFkLEUIgZ$%9rXtF3FRmJ)y{J< ztug=QO*fFBdgY(<87FL?o@7^ljC@=*e5$@Q z-xCsAk%--t_5?a1PqL0u*6s*(>J&oCO0b_D&*U-B$PxZ~%nt90V0W4D;u) z0;VC(QPN@#DxIe%xjk|8s=+HN#VRaiKMZfkSss4v33YarkFMasXJBJ%Z(QCu{Isj( zh?>K}L`-QyA$#SQf7|}Bq9YpCUljPY8$=EJ7;H~PkM-!wlFnPI8hy;j-~zfK&D5Vj zZjkIJl?^C&DKwIT$=qUR^jvuO5d90>puapGCR36JO+(Zr`>eJ4o1h+E8;SE`5<4l8 zE}5LrEIR0wcy^7=mTYM!y0Z2pBrT0h7S^44p!`trCM4?lHfVZG#KokJ9Ex9}`-~f$ z?o|{TuD5#vUXj7gzgqXzxcxHX3U?5Q2TDtYmTan^7HRNwk-Cv;-E9#Uubo?N<;Gf6 zMY`sKA&PaRMkChb(PhMwxjnEGrqYvrG&5XdtK59?u>uc#fF%6U0@5{g0?77d53`h$nxVx9*nII}cr< zGM6yXFT<1^w@hGf%@EDEAoxf1<|ECanG;n>PN7=5D8qqh@TM0TPC^V9(aa|%`(`qD zE;eq!6j80yak+*{j*dV#S6S92G@rPK{&XqanUjozDg#?2&;Hg#I 
zI~-v-PieBry6^-~Rg*#Aj}7je5t}*9fZ;a1AZS(n1EtPhJF+VL&lnKp=`<6Vu#ke# zhi-Ho`rD;-IIY9!+EI>Wj*g@qygvkKQ$(InG(1*YPSeI0?nebH@jD{VW+7;V!7^Q> zR%Vcv(|LL#RKhXMIJ~t6NL6wPz7xq-BNzlYK_oZPji&^Io+qX$D_)A)OH^I9fXLDk z9MPV5;(E0`oOs?fU!V!!y>w@(_^S-|KwJ&aa+vgT`fNzYmk0 zWnghXb`!G%l;p?M?iT{!N2+x5ZW~Z%dGsK}XNXAvCOq6>k#p?kzEC9cjp;0Z(?dsx zq{0Gv9H`AR&hV{DR*$5A=9{Jf_X%uM6uAX&lYsV!G|7P^Zwhf=?(Mm}5xM>w>6l&6 zH&7^t_H%pq>}GsSa%&{%+{hVu)0i|R$?j~*M<4agV8l^*`IKY%Lb`^Kesk^tBvubU znzhFZs=nuX_t#XpUa=#>atPDobC~pM}q5e@1fKTNP^3XV@suv0RXQoZ91qHe_2{9iTMm z8uP}Y>q;6L@#>yiyvxuHqAb)GrYkj?=1AcVN<=C8EXX9Vt;m;Z;im=S#*(HR?G#bF ziH*%A{0Zm23!Ck>!-(gYin}BE$jSM>%Bd_aq79CW!vSYm;>r)K&;aa-0{WKzZr^wK zDbb{22pS^aI*VgCEdYnx8${Ua?Pas#uFD}Cpa-g**w5Cp=H!_ zbjJZ)e;USF%O8qO%M;&FdWhFrpo^Ri~pSFM=yK-YGg1;iPl=`ZLd^ zuiUOf^j6sHvSLCa%+GB9$c(a$p6>`^rkS(x5X5I7=+yvi7^yB&7NC)_p`F0x6?a8e z&hP7k>4iOKf;D$Oks&AI*i4#_(zAcugGgQ?1`Q7HNYZD;BchfXEDm7hNQMImo@sg5 zABht!cHlh_X>V2^mI5Q;tDf>7Iod@J#Kc)lduF zG3=_5s4ZBEOrAAR2U`7;k)GsC(6-ez+A62hC zsM>in={c==0Jm2kFh4`TTRrh#ljJTZ{Vfr55-`l;SY+&FD)Fh46y|>newqGLnu|(s z%!B0mS7cO7w1mD5Uqyu!J*B~DIf(@393Jg1qA`{mS^o>VPqrT4kL=mIFR9^@aTVw8 zpc!|0hDsm2jPvQik;OcxTnTxu;_5mFBNybirIdmu99WBL>k|adCEC)4ITUveH8+}? 
zFS9#POw+XIs3x5cGc;B>ml7#mborQigE0$g4-w#5B#YG{{oj5__-o9P+YiYg#YcH@ zU{@q4Qy+@WUp8CMPm=OJ44i98hBUcu20UCIBeLLx)01hn^w^si6?2QnYo0Zmub28= zRl5CIOzKf=DXnwRckamE3PQtK#bSJwrGcZ>yvhW|f-YOHy?HA;ThnYKpFqE~oA0Mk zbO5fIRv$3wFCUo;pMv<)i@Qlm1@@Q1>_S4{)Ib2R`$XNX$XNwS@Ev!%)f9;JE&I|gg z3#seC5ks8qkvGdT=f~CP8YyNq4mmYjXi^jVVE!loOmVWjweN5{Ky{noOh7)7mqF(5 zn_v)(CsLw6;OdGJMwXsRnYa`a$jH$oV55WzOhg%%%pP#EYQw*{o~B^27Cb90&!#KI zx^qTXU)7clwqTk_gB}b#j|Y?a7Bbk_Ag6Y;rKlUhxybUuFAJ7Bcv#Xjkw00(K~W(n zK^df^7=2KLbE#w2T&~pNKp$=0O?pHp>VS`;^|aW6kiq25PH`-2OkI{FZ?fEH=mG<< ztyU$4nq$V-(G(9x5=50y3u)eb=xb$dGcdY}>?h&RbK&V?xyRZXn!k+K8vj7cKdD!HWA1M-Js^kb| zHIibc^wu7dR>p-I$f$T~rKf^Ax^* z@DQfo!)XKaiy?LrKN!4o#iGN49Kw)DMkI;0@FxH6qqzoir7$Z2OV}Cpx_Z>WM;-#B zy7I!E@i*x=_Z%U2qeZFd^fGVCrmDMGplDf*xsTfXoq$F0AUQ`%L{Wh9ufQ1Qb^Q-t z4r6EBI3;jx$otd61UqS=-r~}Syc(K)Z1(6SwD)Nio7(z zS_&RIC-dDFMDs{IHLWhP&?!`Og&vBcj!!C`_;_dJ5+CMlm~+iPio$(kteYf<=RuXB zWvh(H`9^2s$Ywi~+07~iHz(9$QK5U%ARQ*@kQuTDrP3Z`uozC3&geINwkeUiAFhtsfHP_ije!4b$n)*?w!WJ$DNjs};$i!5S|Ex-+`EC~dh_m`LZ&NWy1Xw#7_wTGfD%l378vfk*~cbA4hOAbe7k|Nn8zec18N7LJo7!^-Y z4SM##d>67qwTj09=6rn_6d1@f&5b2CV57Um&9T|_R2^qckh+Gi-Slxl=6u@DlSB1U zTYvbdQ>O=gz5HDEcrj6H%1KB20^RzJUDO~7U6cqHe~}vMEo@za5em$IFoBy%Wuc8D zbe5N`gp*=Cun&=k>uL2;T?9f$I>qJQF#XSk_{m~wW_+iRO`y=>mc<9O zT9}1d)0D2_Y^@mBQMPhqxw}eR??x!^GC8N~*LmVt9cxN`^8fET^GkhtXrU zN(f4j`!8dd5F|yh8z<^e*qMu(R2ldDq3qKP8vW%9a{OC21&2?R*jeaIvEhU~|Dy9Y z)6i!=k4K7o3EN*#i9y2@Z8(b6W`mB>j1YR!dHyZSammn4Xww$58VI?gQ%yG!%LCf^ z49U&WIQ}8ARl>PhXy&pb)4=`vi?};7^_dr)p!)-@R1B@~q|!$?yA7tpLhM^dT}!uN zbL{s|5rqt3fOcEn7mLXpb9x%ZhYZ;?TAsq{<>ZJ@OBFe|55bqW#V}vDP!qn#x|FCD zNdjBy_+pkbh|E}vHN;J>^#0nE-m6?)$OcI$|4Ro#nbFg(biyzJx#rIr7Nbfjfj-CnksPG_Kj~)1hvYiL5faR_#YI@n+)~)gX19|VkVX9O)EjX))+6f0 z@ZNJ>CxJ*mHjyZbnTLmwFmio}rQfiiaqSY14#cQMM_C$}hZn`iV#I+chFu^XF@gRt z3gx(j)*5>WrC(X$bdW%TNwl3N;3qut zo=4{X6Ux3$;=tBw?up=Vruvo9UlFFjAK$y;<9^-BP1c5#K%Crcy23jQst90g$r4D= zuPyL`?Wt!NUO6Yq$)7lOJgS8DB(=PHBVu4M*eOWFh!a!pra`+qpE;}wdqt_l zD7VO1cr27(0=1CUXeTdI)?}vdS(?7WSEMXSuV|gDlo#p#!oLfQCRu9}ceF(b9l}f6 
zIhP}RrL*tq%S}g?YO3_RGIU(chE)8-q$jbT7Man6R7tG9*?=bUV8V|( zBX?yHb6l`XzsyfKD@WnFyezWVril=kv=IOmo!B#f>jYar&f5OXE}Yo*CoQ(;HNWv^ zCVPj7RKBEI9m%zZ#q1iUsNU~fKcv{2AfT?m9h6Zz&~Y-`?Zg0e*r@SW3&}jH@w)^2 z+79}E8I7TSh9~y6kl*B`FF6x6sQc!QrkV{*5^6%rUqlRaLcm>G|GG^YYve`1FNj$r z8{wkqqW8I#%4uuj&8*srToAw0S$z99V+2h5o~n%1plj>LIi@dl=N}|P;Ux-)YZ!$q z><7_-Syg0^P*i%uN>ix$9~Se`=SWAbh(<@@Bk&5gN5wdSb1C9A!BcN-w$~9Ud?u(A zhug9!87m=Xow*GKyOCqIO78GjaI7IiVLbWXi^y8JYsXw(3hU-$otP5K(VSKvZB9YY zQA(p1)#0yDI89N#Cod{X(MeUH>`VITV3d|p#X_2qEr)YCS&@Hg%aiqUMtU_CqP$`_@0H^ z^JyKAq46g4rIm`%<7>btgmGT6>V>Qwfs1kW+pv+bEInNIn;u|*)y%N-v_kqNNJt8w{1KKurWr?#nd$Df$zixZ z-}}0L|K0b${&Ag8y zhL;3BpGCX}CgQ;553c?nu#R7nt^7GipWwCP?8$_dn9}#}t1lqP&?8sk`V^7gfhFo4 z0^)jt<)^1e&%cV?5uN_5YNbnM&YUTIR{Nh&#RABfE(JOmoA*<>og5O!uX!rv4lDHV zRc2=s(q4=!u0YGLfAFSxpGv&uBf^lbV<;|wnCakL$dwrR_xHJDh`u$kuujO?|L7o&~qJ2dEw(X{AAymOeljGor~ z6g8g)&!%USe!0Kqe(LRuK`$Zpj(kMcQTcrd)24Id{9#Y=p7_BJn$W&?eZw}>jz%(5ukBEx z@jLNgGq+3%gF{ruh;&$OOHA8x0SL^^9wF;EPegjW`0kSus(o*Udhh7)w|g2bLdO!B zJqtewpIgE9hc$K7-5Rs)p}MVaqjTB9GzZq8T)<0 z>v8j9*&9(nMEYFL<4@!q8@i6wr3qH=loNSQF&`088zc{}-%=S9lqzGu>%JUb(ebWg zp~l6Oj()80xl?Ft>N7^cNuy}zFw1QkoYOTIY@|w5*fs8)$PJvSA>t{&pF_2Cb1EV- zD5QEfP|nva9A+I&Jx0@Kd6WsFb{>W(l*3)^Q;l0`f1#IcVvVZ6sBJ0i*Jd;!U#Gp1v0 z63T!_blM%?w9gj%wQfRvh)Am#j}5{O5lyBk`W(u4PBgOr6Elfu25#s;g{K;f|HBuk#gH-jy6Zx}dV!z{Kb!5>3tjhMcM zG=A@92yzlEnh|DrlMa5!)8-{x>^FKzuXF0jsRj)!!&LP?QuH{t)!VbG1}=G1iv#`- zfBy5(>_Cy~urIwIYl^qseq$rDnRn&*gc?F_Hl!H8?v|A`NCjXqr+d|Bitt9Ujr&)= z6L4nK+1jja?~2$8;8h)LfBj`kY=aBsM+kMd;{yIWFDVdXO>Z$+&O8 ziz?FaELw24wIM4GEWk9I?(`5-@r0Tpg!hk{tlAG|b2`6v#6LFgWJzk(e-^veq*>{GNAG0XyM%%a)ZesNE%>fA6(mD|z*;>ex6dZRzE z%$Qeg^e|@cponnt*gq_cuJ<2|{`T~-k$ms^WE`9)6J7645ZRWFhfemd3gBbtArWs$V~n#Mr0)`uoF4W?Uh29KF}IE{>iXkN5wB3e+OgQ+T_xK zEHx|{%^H2m`0IqF;*tmU_-o2xC1=Bkn(D2|_rawg9xKGl&zPdwaNOWY4eVs|k4j-} zjj!f#kD*De{e6PuhvF>IJ$34|`wc2ScC&9&&CNzlm6S4g#TINH;mHGZI+^d$x!!); z_=_nwWMJj+porQlRYD69ss}d?PPK~M^EJ_cZh?3hARv=ONgH9)LK|@+)8BHMX4SS4pPctYoBtqi 
z_CLT>v)-3k`wq)z$b{6w=ZVhosMmLUbGyi~W-Q>x`hXDOK2ND-$M#wOxB2+wt2kL7 zhPLemvU#dExGpV4WlOuf0Mph6O9TV%xc;c{l9vp@M!v#espHKu_K0eY zQy+ukP)qA9zHdp7S$>ooocUE>1Sf>48gcqVM3?0pe{={heElh0^JYsl@MiF-?bEr({ zhpAkGC2`eJ_+&`l2BlE{vXI~h^@j!J=)Id zm6?ZfUP!Ph$ApTU^;euOYUTX8p^ff3m-rx9@j$6a3 zmZb`>lz|DhiDfyZ>K!2E_dD7~=0Z>J*Lz4N44-F&e#a%8z-yGaRq1~I(bZ4knN{sa3d~n!&-L{Gp_^bG>=euCIt?hb1mr%Xgjbuz3p#*rP zza`!549-R8w)8&4)d zeb7q9ZOM#nY=YozB+L!l3Q|^mc}$d8e^E!7TyJ%O+gO$N?Lb`+8S{29r8L{k)949f z*`{2dWFw^|^~bDyTlpP-)Ul4#`x{nhwjYW3=T{9`-$tk!L_CXg0^~N^2+XX{yNM5N zhV+70gXv5%uw*pNeBsWtXeNo2W9DhY0T9W;MayFnfs@ht-x1qYnc8CiLTJg@5B^e8 ze@Q)=Lw<7_sqgV`F!i|zYxBoN8-CkVOYvgS+gdSf>Nda>2u=_;qG{LDzVD=n>odw8+{HTA1*`3z^mqP}MhoZGi& zV=cFXmh}YOLsfmnsQm&G_k+{;M&)I>+k|)`{$QGY=>{X68Ztslu}Ow9(0($3%>l+@ z4~}QSF)l4DnX9TdW9Kfk#hD{Ak%w81ch?g-^5pI3r$K)vES!!#y$a})nj|0z-X65} zH$lv_NT9$*&jVi!-Tp6oV>8tQsJ0OrV!WXvr`ce+F-lY0hPHtC2$_bk%45ak&@oLJI+LYqOd5LikBhRS_n|8CZPKwOQ#+v-&T0>c*36q8*#j zG=^j^bgQcU3&B8P2|glGy|Ygbp-@87y=bL%bL=ljf)dXqndOIk?dCg0z0*5FJ$yRF zqsZ|iJ!&>l*=t%OPQo=Jwh9km<#&8fDB#%@{Z0GFWO>u~P-bk6TWtA>ds@M-!8X+6 zWB8(0bW%y?r>7Nl)e)sN&z^w=lm{bUf6;D;((wIPu|K_Z2@W8$LHh*#`}Q9gmFT4T#%U8 z8#?ss3fB_g<=LcN!**e|uNPL?2PgWLRlSQ<{KCHjw| zZyAl@IVe8#yoI`fbC*a}{Wbi|V?wiH`B9T>_7aBTKqt|?QI267mOj&!6} z`H1Cs?QySv`hY_oU&rYDlsF$Xc^%Sb{ZxpM0}?y}EQH$XN!c}$TE8){AM7>Fer0RUutEVEVP;&%JJ{t!)F{VVZp+;W#W57>fp=F#Kp;_iG9ld++V#bC=X za{J3NeR6LsY48{^wIhK|hz!r(n2t8!o21pdFjsIc!NW7fW82%$iJNQWuPEDHxM!@I z`}-KT)F@^&-|B9RIhF%h3!vW-mq{2{a#S&C{vDy$-FBnjYH(52 zH6G{dh75=!8tFUi`PJ+?%42@gFT$SrW%InY2~59ptfWPM>wgc5 zL0bzy9r}v`e4nF_u8iiv=${#+w!U-w zwPLO<=z+iZvyq{YQ`O1@`^`*hivN!wS_rcjr0tKONDuoqAtE zO5h3;^(G$e>D?hqzlwRJyz$?EAiulhsi2dHaA^>?a2+NNP#;@bqkLqu8Oau&9+n=1 zY~b==i-fauhsDuV-31ZcT<~E0G8%XO@Yr2cd(BHSQxi#-Yeb6KZ#{pG@?WRVk%T8i zD#su0!y~{07m8Tk8zanR5b8tM&!^chm+lCsuj%_2$T=(+N}E+7c{3Ki{k|E;HXRdR zEA}2~wCOx{6DtWzh`r|KzYpuX293=6IM7z3UaeGWxLhT>HlpJW! 
zcwy}O)p9#vF4%l=E1{VDdt!$O^7`4&;l#^z^JuM+Ul`hQubf;TW|K4|-O!?=#VvW5 z%h^jC@?yYpTg0$Co=C%g=E*;XN&bsPLk3^zlD7^ z%x?7{)PJ?THST~K4q_Iq#3Q;nV?T2gUvoy~vZjmF${HQGr%7HD>V5vXLa6=ukp=Z? zRUwT~^b#p&iLZP8p0g$ZNrj;-n+qOjeB0J5rWV<(Zm89}vI#F^`ITq}*YP0Ep;cyC zQyJEp!LfcvFH*ZFXEp1hck-BSQg`=28qa8Mxn0!EYb_Q|d%@{rp?~9!ng)pg1;OGq z0qhw!qHP-T+)DMO%SXhw{d~rn6DJ#ip!l&AA7CjMK1#yu9yP~Jsm3{ zU3tEk>+ojGNF6VxrXuVKf})yZGkW^swqiL3g0hh{?V%$-%(w42#(N)S9e|B}2L)=k zfIRQ&RF;B+w8S_6bF4!OLJes@Y&*$&y%Lj^14oq~;j7L5Bljv6PR`Q@`PjkYM0+NA zu_g5yZNHfM&^{1*(bh4t2fX;%c3DDAlzUIEB9OCjt<}8Cxsh3eNEfvwX$YH`(M;mk zreMgG`B+mi9t0hG0Z663aKahWvA7fz=nsoQKWTR+Dz)L}gR( zB2Wi5%W`jVSf`%f5p!I*a5?-!0=e`0l!0l0{i1;AJ$JCokYCbE5W(xWt5TD5BG{UB zH@sHHu0j^R;|^)ZciA;ADWsd-cTe^E>ljf80%@%)p|`Dde)z`yQZMczegf9w_Bi|( zs?c=$KGGPMwOnuG6|CgsEwQ^jiL@P?dgo+WUBavdeG8W3wga(*?=l1V)T{?0h_Uv5 zRK+7zwt7CS*#`TC_KSL6me)PBR)^z8X_w5n&2NC(ZlQOg!Es-~hAoufzrfX~i?%y`sLUGYslHK@_v7NGh6~iR9Pob$#=N)bA-u4tVnpEgMf02^k zaQK8|rn=`P0v}yNb*;UTh(X5eeM}=uqw@~k8Q3xHCS9c?Th{(7Rqt><@Hf)omb2w+ zZS5a;=5ybulLdFb)TC!y( zpP(Crm-d-dRPI;SPhD?R@WL05-sHC|-_1NkFKBzIP(+xI9e?XgBpuvmZD}SHbu+IkoZQS+iEyH7Y zJ#~CCEqJVdz4^hGENLRzcI%dkdiHH(DM0lYAV`W7w$%MeJFUiV@<7T0N|rQug`+E` zr!cSjk~&Exdv1EMA!X|=Y#?u@j$iHAs=oT_RP4`3rxXMYRDclSWFt)Sk-N~e_T%i) z9TvMc?a)HA!&tdYdsktyc#h#ITqV)NxX-_W(ii##(Y(%LaLdKPOe^H=s%G~rt7b3- zx3Pg`f=L*m@x~i<#6YH zFzAbxfetRtjT-CE3LiWjwogU3f>Ex8_1om=^Ck^}ebesD2GblnyPhcNt;~+7oD>=K z){*;^oxzPd~KC_C9)OWyzwi{)(s zc*`_qeB@Rny@NBOw{I})rU`>M-KhBuF}j{;IIVQykv{rDBCSYD^(2WwVc6W*NPOk$ z<~8fFc<3+03b3dkee={siJoc}(P*RZmT@R;HTzyJb7i^~R0E)cAZjI`Y#IAbDjD(e zV$fypgh?TY-17A=$p5z=(7IW3(Zhdz4+6{@$wyCrmQFDE%+iiG^=*+I1d`+B+(Qmc zp)KKtQx3IVtWNI$&00QmnPylD?q5=~Xm1E&MXq@Hy*KcP!@Y%A+k_yfI8LL!zU^r!U3O zjUHj6lTD~_PoPED27a?%ohBx*R-gi^d*I9gq(0&5P86S0&(WgkX<+S1eAR7(#!22= z;1-wB5N6x&L?8P3D!|lKZSAY=%%J4WEi43jT5L2D z<9*gBbYD%ZijhozLMa-SJKR<;P`H-HT!QCOUwdMw7hjP$x{1(m-?gLYogYrmtr8~U zQLyQKkrr2B=U06Fm;V531ztgz&Q$Oul7eq?5|p;X#E->e7zusa;BwuZ=f~V==jsR> z@CY;6fMbsRl`NLUgq0`vd%R}}j>WumFA{wb05QJNuP=2cS?n8o2g)$Kk8}Jcf-fbl 
zzldc-N}5oP?#)Yv)#GwQXA}3J+uxukdCS56!CY1kB}<7@Q)ovVTi-OHr3MKCWKTJoqhb)zNVk zvE#m2xD4t4^4mFUU{Jgffw0nlWMTgb7`%;5X4f1TAklQZ*Q`Amh+y_)^sNH-RWA~g z%?#O|ks|bdwjJZG;p_FYWO^;rYbpvXKllB1=JWNI0O2VGfuY3QIEX~VK)g5u^UxV* z5`qk>BRdtD<<6(I$LnFFWw2-#NND@@sy{I&(KBUn8Y|-ueX*n7C?L#9jy$zGT)-eA z{V|(1qE_8Eg<0yibYKJ3eWsLE7>ISbby4O6tE;J@(n=K{nXRQ_1hw(ZCvskrT8o9S zLy#AWK1vkFG77UU>NhG+S099lw;Doz(#uo)p1b10^-rJ5{F#^h5v!q;g`=6ZX>4eQ z_&EFtpzTT}n^2)5X4A6Dj+MiYq8k5kj#_B*0!%iwnhCG{aru%uIF(f12rj9W@bY>` z!EHENydT5ax9=E@C+~$h6OEseTB}kjSrO;mLU75M`xJr{^wsGD0{WcYS-K)Rw7C-e7K9ZDEHTm-Qzpe*`hK1;&l$b*{Q=e> z^UFZP^eI}b^xU_O@D#i0S(bU-Ypv{n!zU4A49AI8WuU5^pY@l{PI15m zJu4R~s-_0ZjuBVLO8@*Dj`P6$fG4oq*?DcTksY_!_s*VDm5^s>5;RRfuV^JpQj*p- zGJrO-?xkdZt1OuS6sC`yB(taCxk&7^JU8=4!;)j2pNOa+Ij@&i!q;#&O==5|J zts-06Ly84wd6z4YWokb6D$O00KB74%(aZN%BI!LSavK*zP3aV?#6Li=EH?T<;Rh+3gh!wJc&SukFSv{X~AmbtW zVpvz@)+Q-Na(F>&vf>agffvdG^5qPCV{ie*#C+vL(osy)V zus=s1mae3>9=v}>W?FJDO@+;Iz3qnlKJc#sn4%8$%!pFH*|`9ka;>SCJEADeUpOz( z8=bhJf?p`F{3v)~fhhqUVUNAEJ=sDKq>~AM0QjTObGV8mwl%*IhM_h5K5+Wz=Bctc z-@1)m@5L+Y*?<9q9ECO^`cEqPgjzrE(a>@_Eeqa`jB6v9FWB7Fwt2^2u62}9XZJ5L zFvPD);Uu}tU30a#t`I~M-;i`ze6A-}j(MTya0jZ2nt+LxQVNw_JaAf`q{irzOBMK? 
z5~5JGQJKYTc5u;?o{f&T+LG~S7P$W?fM}pXLLy7XS>|!;`ih;FvY_Q)a^Hw=Nhv4J zJNO*=ua#D%WN`i}hOUe{&+O(DUrCqP1uWGL64Ltv%#ofwXqY&%i1`DLQCkb901atu}lVDUS)puT2^NgK5R#Pw6{(%j$LvPqq11B;B= z0Ul{f4X*9B7pq`n@OS1?T}oYGD!|WU|DHNIugp{#GNg~KBo%!tv2c=&cT%qNL{Uxy zS6D=@#Zpz+_Gj{!Y~Z6@UDI&($1@_@kW>e&ZS9Mr^cjK5T`hx)9f`9vHr&qQ+36-l#gc+bMMj6 zc)^FKbeIA5q}TVClxwnMWzusT7$6{^mXr?S!PxnAmm)Q~S2ssY{UjR0Ha|NCyoLGW zAH&w^FCp&>N!Q}mL%7^m?cbv~&sz$AtBXr$gle&oDj){kTC4p)r4t^;DSCIsV{gAC z>o4I2i>p#*%}HzJX6_1o>9D3e`gpG>}n^nCxcFfGTY(*gt4Y$7xc4xOb4 zj6BoZ8&Mfdst3BtKdPc6cTW{yz=|J{o0Q0XXLjkCMPk1~rYDieYmP@D0thFW znh<%01O~(3M}8JN?Y{+{t4u>Q&1blC^Zh**5ZHN_>{HuIJ=8Fom#8v<0N0`J$DHcv z=PdN30vQtJ4PGuxjM~^^APZ8-m?a6R&w<*u|Jf_6&#_51gnJiUUK|CqSg-GN7rswi zx~Y=IY(x~gqX*1(gfH1O*(K`Kz9Z zFKH6&O1%O%NW7&g_PUd>S$>w3^An5EUn(P74qE2?q>H!t%3_|URTmk%_IXpVMA&@N z@6|)1+J7aPQ3ri&#=CcuK3*&>&3W)jq1VVzoYxcm6Ko->m#!oTrFqdbuzkY(k<|wq zD1+#)%pRXjEY_b@@Z_WJ_@=Q}?-z`ErR_2P{3EHZX_suV1B}yHxVVRFhZ;{!*yIpi zXF8UfoBSHwK4FL+OrcMYF&7%Rt>3R&?LU(@kHTChvf;qX5xQ7rk$yCQ{&eYU#gY$x z^rxA=g-}LIrA&#axbG&`<@p>6kDzQ5L%5?llNpB*XCOpWqN=oSMDYaXvZ;b(Jl8wejBFfgm9;+s={jjrP!{QAmyEW5J01Fd`#Gmu?Nd&7 z>*-MO`uSium%& z%<6HfhY}FoPDEU+NPVMgG1&cBE0`fIIiz;Kk(OaGvA|6(FT(qDRa@8(@5YzQXm5QO zV=R0piJKE{dQVuP^MmlTG0=8*hD>Dcm9*v_c#+&cc}ix!XZtu03SnhX$ZN~##Tjf# zb~8I8(eY2?r+$p9%J|a2&V){4aA2xYJ<=JWl5V#-5=K7L5lW=V{V`(Top^>pfRrnt zB!KSh3^O-stCwGA;3O^W>XrS=)Df4{zHGGD9g<2|rj4!fK^V(!LTpw*~2NtkSNEG|veiB_k?q zfL|l$9ZRI=hA(J6`|$%9m=l1`*z@IoXPJAD_MJUHQ(ck&^`JV}z(t<`J5z1%x4q#_ z0R&U#hwrk%)4{Y3IHPUgUpr}?Wq1%q(<`XD`JnO9QX_fG$&32)gs8Vihpc*Y)79gt zFp&FUS5rGZUH}M6q&=fY*jo31tzNzM6@KuFzR-3khOv~gjQ-F+8 zh2XxHW*Oe%#^pHV4mM(2c241__e$u-Z0xtsv}bU_U-?)|83<%}2f)8+j>xAu(o8eZ z?qN=OJBNnvOP5&hKnw3gbOt@46T@BwfD59u5aa?!UG+PPO}dnUFns;u_oXTDR2sSr zz7nK0ap@~#Eri)@R|ma>D4h$d{6BQS#5Q3UUsiqJK>1Q=V(S|$(Xg0DW8A5n-~i&p z`wu-}+dV02ETW)@>)1dmbiz*=MH&GN&$S?0LEx1Vz`zUx6L^vf{Sr8)cb~X#X3RB_UYvZ{%&rel=b! 
z$`RCGsP?v(1sA-6E9)=q2I_eQCDbh1{w|gto$?-d7Cclc@{1XQ%#l8x=v!q>$`uhH zlVhyf5+~CxBaH77(IT=xJWzy!Zcb%03_z)Q=Rz^!tBNIuA`leM{SlK*ftCnQFpm7I zT$?Se?ufA>*S`xMHqZu5_KbA1pZE~_J;iCy&hVrVY;3m6CMdXY9xF*o&RD&vDH99f zGsvUEK3bxN*&Y@1&<*X0Y2^H2bPFrGAQ;Ii6nGkWVl}GbLvbXO?APaTPJF{g=f#&K zRSmmhH2*Zl)c-(n=9u?Q)*l06PYal#M~9+SOmq7ZwoK?sxI1`Mq9cNN8p+>jWrLYKLVLu(z`HHItjB^8 ztjb3OS^Px`S@h9v+vS{*4+k>ooUK*?99e^;Vuc(gphv$^%Vs!n%%5QN(S9{caQA|5 zQlZ*apM`xd13ZxX_bxOYm%k{rM~|Iz9S9PzWrCt@F*~n*04cun&_eHjyUBr~@HhTA;IlnCd_S@H~1+ywI8r>NAi?Ppq_C#ihRm3j3_NX=nsz^wjH z-M|bYdX{IGb#;Pf8#tJJZLH+~uw_8hJ9!AMwZ$(OtYkA}`lej&dV=qwQ2T~sSS}tO zp}wW+8U_xjsfl=Ee|(8Wvwex8T>vFz@7Rl{sy3_8&)eV_lo?gq#4_IKgI8iamhg$x zrvolvo5%aU?v}ri4KHlaOSDf#!;CB7HJ7DQc4P-0u}P3O$?H9Kp}Od(__UoWj3wc& zI%P}%zdV&kNc6>%T<(ah2ycv*aeXZv^RdEUfQih)vWH5h{@sgj((3(CXS`KS`}5(4k4obekx_z* zm?@k|3@xu5Cd$U#bAk1zGjW`fX67KH*j)iIq-Q4{OdjPLuPJ1SQ!u)c*&@_wW4S!s zr}}T2?HxM0N*o#tG@Z(%v=GX`cZk= zLtEy2jCn7thPg#N*gF>NvCW3|{jyik0AO+{6xis_9(Wd#BH5PmlZAY%P*e>QmTe4R z*zH(g-?iG-py*{CQ6;yVsMEG+od|VCLYp5!#-C45X|C*z&ATCxJ#O~f@6h5kcQpT?RXGH*>=(=Emo|ckLP1&-vvf^cyG0M!w z=`TAS@H1b-TMmf~F#X_AkMQVtVVcZs*kmtNq~7x-=Dc*QpWg2gjG#^gYcMw)`+U>4 z)Jpr0ufT~R%atxgEf$LFATmD{1g_-FpI6~rdJrayPewbbpQ7B}?)-gz7M0gSA$tw0 z{*;!9rF(G+Oo={bpj7=ytNYekZ#bs4G4U>aB+fu1x7~vyPlSbB2Y0&av5yqWUUG0d z!Mk=(KuSj6ZG+0sEF?cl%FG8bKdv*TAhH4${GS>7a{c!5w94jD18uI;cc3L?G}Vds zd(UI1ESx6T-Q|KyQoU#gGIBeO+8fAKg_{}Q8xx!Uo9$h3kz|rv5*7-p)a09)k#vwH z8jd?eu{%V4DP;p!fD5-_?^y!k&g_Dwr-Eq6{ZF9XRy?6gIY1tA#GF5*118+p|KT5% zMm^X`TK4DiAHL?F18?D&?eS+*Dp|y$Ay?EBekEF0X5X&$VAY%W=;B8GsA<|a+{Q|K zg^l(H_b=&yMFRIn;=I7PZ*0?xd`T`M7Z?USuL=3BsVOP_`#2t$ORhC!G>VUS;EMI^ zrenz(n3YWa7InsZWg{IRq_@u(;t2nh|lT?mbN{>!NT^Zpo)gFHCo3Q zqQ+*;&Ots?{fgark0R!kd9SPq2c^nOi}k$3ftCDp=Z}>rX}Ra3@g+@YJDE&UcpFlXQUzQb50QB-6^)#~d{;{}!FVh_SF<^LvjPU`G zz@-)bTHWM)eM{ZPP{&Kb8{qC9O^$!IYw$Yuo=d73Qt?o<>K&A_HW}BpX~~m*^V70K zxGT1G`_NM57xtKhh3qd`H5bb@vv`-9Ou#<4Xbs=8fi75sESgUwPUe4YV@vGjog)eH z){yWwu0-v#Fkk7|1Zgp?Q_E_zJ^oq{*SPi+VuR?l_d6@viU%&aJJoG5B4!~a8G(Ii 
zswiq@tXtx?Z_C(jWt`u=fpuF#gvsj9BG6@jemd6Yh{+nFL-y_MUb4VHHqh}^Rs;*^ zzCrRJhdvKn680ZL-~(9Bwy!FxA_0g}-3YbHw`%yD_)@}3pol$PHDXy#Tba!eHr7Wv zg_0_E`_kAa21squgqByWh9bS*zXB35+@Cn#A0|x(5LKqFqn4raZeXs1=D5CVM_7 zFwGI@NkssePRT_p!6d`V4oUR|)r3bBLKcf5jSM9i)hnl(74vuej#pauRvMA}fNijp zZiIS$Z=>JQgF43BLcKp*d9>xtjWT9QGEz4C6-xYWBy9{ z-uW&im!{_EFWR3^8J2VyE(3+Xjc#9hTsbi~qh4J9XF`4P7Z)Y5ai1F%Ba$KRoK@`{EN8PZEC5n&*IRV>aJ0$;Sl4V9!so@t-vro`5p3<=*$0- zc8D(6)EocsaV}kFLKZEvO zrmQMpgp~h#mDx~v`>w=qvf^g&meYBK;*@rvJOFqb+q0R10&yc(Z<{5j>LL-0bfo$}r&F2iD zWtm@Vf&)Bwky(w8(XWjl_4*jA1s)Ud>7zj8e?oP?M=7n@J3)imR|I&lJ0<4=MHW1h z%8skQg7k@CfirOi5F27#M(jrQ-EUSt!JRcj6ss%ZYv*Y<0R*Y^>u$nQHGus{{n%(w z?pmZ)kJ3NCKj$?>jkynn>78l`Up7v2pf=Bi141R2aiWJ18)$~a9;?AbVt6Hl_?E)L za>a+$faanIYVaD5hD%~@-zpU&40X_5>76@_l1}`j`9QA9NHKCb%>0`vFlD4vD<9~^ z2Kdm;zGeJ!uW#`P?@9Y7_9n~j3l*Nw%do#On#p=MY*QxRq20R_>U!2YY=0i>Ye7#e{r)fYUYiO= zCOVbRS&%NXjAGl?s7U%E?^VsEhlbO9d&-yb_|(GvgHsr?$;^2fuKfoisjS*`nLuXC zj9|8U1R|h&&Mz znPGZHbBW^dx3ZWXVXXE%F=?u`6(JnsdRdu-q*FucuPKTB)IV*^pA$w{0XMjR==r))Ql|61{tbEtIS6qF?dc4eK>PS!_sZTIupLH5IfDk zkz#dONvr@eAo+bF8kAzn1FDjiJY4#!joL8`#l*UAnZp~aRcvCpL#*YVc@sFkah`l) z`;A5{Hg{OQp8vC*wcI@Wo_XO<0?zOE6y&cykE$oEMuO%1w-u4zztmCuzeVFT$wkP;n3ZMI-Ft0wnIP$`iP#->< z+$e!#&{8{ACe&yY#=LE(Riw>#=Mxgl%tu5UNB^E)`|5D z{%*A{OJoj9S6?i=s3yciWH~!7JVwkAwVp`-O)yOE34oY;bj5Yy1VF5315<3EQzk*O zfkP%pvkpg2JlOoApTP?6M|vzLLSf1eSsM~7`1OdZl_K@y? zmog^I*ElI1&tK6e12gyF^10CKENQcmy2m;-%qsfGy0=iUdUe5RQka*nNJtqIC_Sw= zE*W88ePmg{eb^kL(%>f8C1Kfqa6Hcd@(M)!{YBG#fZQ5gu#;xoTrm?YDrCfe(+X-$!SV}sfc0jFgQ3?{Bb zHB@q4u=u~hBCGjmc15`@H@d)%4w%;9UY(P9t_l zvbC{IStB;x=v7@1pBWx2Qui0wpLkqbIBS8~O8w(Y-HmQuNJYS3p%{|{XxjmgQA-ha zUV=;#)d%-?+2eLN!sYSR*^Kc!`>70|ADt3%zN24_8d%9N6?6>PogMyQ7_i5Mf+89v z(HluDXX@kZ2g{xV4}1jwF@)0;I32}^_psrGo+`$LcFtH&lk?;ZCh-B+&s1MG0mI%{!bl%4_f?XI8e~Ev|ch$Vs7CjLo&qM5A9^! 
zZEn*k1k2dkvUErU8hf*Vl)*5MZiPP;8DQG9+=H8+$V4O1JAQ6>64R}yjrVez9cJ;F z#j|>$5j+^))NG(-=Q=3Ex5i_ui;AwG5gH7n;;cp;{lKzbbu#_<`jLC&whroIhc@tB zKv8+(I<)jHM$4yE{j>sEOuoVtw}!0zTfqaeb3G8S@jHFl@^&<=IO&Q8&qz@K)^NkH zsA|7D-^N9oAhZK0e-KO^AQKbULf6y83ay|)iC{(eCT|BD(I86tW=<=Zj{K;AOrrQ? zYljko0%Wb*Z#7##+%-p=v%99w*eN0+`m#FVij>Gc z;@11XEqo`An%ljCMjreydg!#Bq+E#5Sd`Uv|=Zwh{W*C{AESV-(L)@jOn`&B`))B4UKQ2_|)Jg zfY~?lm2b|UD{=tVoM!UWA4_v5lTYA0VyLW}A$ay$?DeW9_)*6O`XT>Ia(V~xb1N2X zu8y^)byeR=UNKfT5*$?uH|9R)V&b%5{Yp7XOJLAbTUYq%9+5848_YY#9VAr)>z&JAth9RM^|QF+a0W z_z1E($9P_;RIa)&^SLhSV7w=st3?i{NOhXan7y+XnKPCIeL>srAxhv4MUQ2xK7PqwwK# zegrbm5`1O{<@{tIf^bZja=W_zXW?zdIokTJO~@VR`jZX6hn0@@@*~0ukfO44_hHFE zt}rzBfthUpOcw$xPZ)R=VnfB**<%i#cruXOu-khT`DFtL2}ppecMNRZ!eDCS&I}%M z9JKU^@U#{_@#9GAysCWu?5_xc4730UWT1J-Kqrx%vV|Gqx`sh;S~-EDMH=EuO$HK| zRnC%uoFqsU$Uy3$LI$D$(AftDfF6nrWa0%g8-G^Lwawi3A=V^! z==(_#U&|fW^3Y_U9e;yt=tCq%#leW2eO+^CnmED8Ap|yV8(>TDj)9|hESx;n!o-Yo zte2V$WMJyO@MpI1fw5}>{Bn-5-M_N#kFjZepggG<0iUFIAGpi&i6jGck1LRYVz6<+ zWS~w3GEh3=`~Ki$ps8OV>3Ib)VXL7Sz@>}@4vGRby z>YXU*yS=o540@*Bmyui2iy;4aID1FS?MpXiUuuzk$%2I8@Y}ir8C};{0(knjLC>v^*PoY4v&irl(8jsX5BX1umgGBcW220kD@1+oD9UFj#Yd(jg0Nh@b`^_ zvsVq8)#YtuVKAhdxzZa*`rzcs95zo7j&PmI`E`YW{pygxpuGu7{OAk>xyK z?pA@6(I=P;M3pOsKST1StuQgR!EzmQXc|*OlpBL~GzPt*xiRQT)nMKKP-GxR7@Inb zxX3LqGqAvNvm|UNKZ2TFRC(b%5~H`k%G3%P2J5k^^e{wIzs3<^1a-cmx;aDmeUgF5 zpKBi7k5Kmj7@E*gVN}_Rl3_ogRL#fnacK*@>@4NFQid%}w^^mVy2eQa`uRh{fY#)21a5l1#L{aYzNSA0mNsZtVi5wYd|ATE)s){)9 z1tPqsV;HO8S4i=gpHVw_4C!$j@RY6t^o;#r<6V!GuEVO&R`&(+tDnWj_2FDqFpITt zkDf$!^B3&d^4m^9G)C1F`PZhSIS*rNY$7a~&0q&jt9Yy_{t7kI6aEK*Z=DAnp198&`T`nI2VLa)wnVjXgs-bV^l1Z! 
zo-<5neH*&>D-on01O8}apvilXcJ73SxdZe}Jz(VBi2|B`Q}1p94fu?!p66)?4)j%8Rt zQ9 zKcQvce<3F^8dg@6JG6k^miJIG%vBqvei#`jeuV-V=vj#8Ap_M;oWZ7G7w)Jqb5FQt z9iL+jm)k%M9WyqNp1B{aSB-LUE|vE3dzv8w$qCU8n-P$&V5_U(2Q=*d5yfR~@V4h{ zAfrGD5*o<{`Wv%>f_(j4FR=~u zn?N7~2_JPG5({`a`phS)+U>X2h9}!}=k0Cl^cz&e!W8?PMk=1+)4bukp zlTHR=4k&?kp8p-q&-^Evhpr;8@fAd-iVzaG6)xTzU}5799XjTL&hzwY!uAnvr;SQ~ zj~N0PXaNw&Ko2GZjhw>vv~BF%@`tJnM5F}@46^@ne_LuYkfc%qd!{etm4UhhGLTv@ zGSHNq3}g*W>jG?UI1T9o0zmVUfofzhkslMLXPJqJv_VxJKNm8`u&MllJW=^L$Uw1N z(!q2EoP(tgu_pCl&|&`4pQY!q}RwZSgN!8K{3Zyey~5K`iafgXrF18N+k;y+#z{-6S+#;(HUGW+Unh$Xpx(4ykRa8lb;rOc_6DGgM-9MnLdKdvt z*3i&#h5fodWDfloT1P%a%*Gh#>zF~yaXTU^zg0eZfCbZ@k&}UJp=}z4ps04lW{;@O z5p^J4k$)H>s>rZZww>I*teZY&bKifWY49>~8(u=>wn~HqY=MjSdRSPyL5G4%hAuGn zZ%6Lvk2B-6cJMg1ZCV43X)@4nQ5k68RiwtHz{ZT$Z3|dMK7;(;@43MEqw2}WKn9v# zhw3QsxceVy+4D~{^q)nRY%jvMlpv5uS)Qw5X64F~%bJGXuw6fi;)$PU$K!8G2KpI| z`#wdu$10c@(NR{82rN3yCE-;1l7q(Bz&`j)lO~NK5YeJQ9GLhHv4J$7IWka41rZX| z^us{D;x$xMNnkGb6)iowkpF<=B_AG3NCu)YICTr#Hb=qCoOHS?%pEttacwcy7d(TU zhA&Y_V{qUSM8mfsYkLvC4uOB5W5NhOpnmW7D9EXVi!D`#aff^2t0)@#FSPW(gcSjF zM3_0W++;{@zN8w&e{4|aE2^6_gx@I{h#hTp3u#dsU}jE7g#{v{;tV8{e??tO7aXlT zp>JXjgTN`2s!iA|fDE)0_T)@l=Se36$>Tb#!o%Ot^vr*vb>udRTK6M1yB1*~DRB4Q z2x~h}=rVE(S6y{m&n?xyel0Q(vlV`XY~mhdRMBmMAHP!uB~a^{_h3a(Ahc*hh;=$P zlpmFkLn1;Q#pZSKuwY3YJ9ut8fZ_)w<)?l@)7YC>;T{SjV+ZKyt-wn0X;e@Ag2p{3 z5#{R(BNKTgSUPHOy4ry{=-J7AZWr#@O8MW(2MMm?+mVDHAnE)7tGtb&W#R%oV@FmE z&&W0mD~hh8c8_{}K#8#GwgcD@WW#JwZM!%GzdO|$FScbG? 
zTl;;~%@y~cxC$dH)3~QL4|qQI&i{^@@^V-*;($H0U7CWV`v&t%!7nA1?;x-J_AI&V zNhbrT#4`Q^O@kMa+weRhQ!3a-dKa&CGqDUULsysu_8@nZSAq67Ap^-l_fMN2_WcK1 zp7~Ek{%LvUpJ?9mSJdzL8I5~RA#T+Mn3_32N7ofW#U~*epC!JM-pNf|2U82u-S*I4 zxf7N1l7Vg^c|BF*t%-TS-nZ%wY@no1QMN=j&~E~P3?zIMK77uPKn7ZZ z&+MR_pA1w#bqD>s4nZc6ffVo)DvPsVM2m}={HU&S`?gO0Ha3@z3mSae0r4$cLT zfkwVVZh9KDb?JQEFt|j&4DoaSH=6fwXE#3(tPSmaP`Kt3O$Pb_Qs#gXrpqC74s|`9 z)4t~)Xx{e^H1>avEh*)&w&4OH`W8NL-gabe=k_r|AOkG`0vYH5WFSg1UPM7@4Xkx2 zapwz%&>rOO{D+zVo#M6cLppK`;$cb@-+)AJ0IQOLDi=)#B3;qaJ_a|}0FE4BTZFBx zm)Ook)xJ3Vamhe`Lrdo#1o(w=!Exci>DB_;#O=CbQ|{KO@O_xvYX_Wl*M9s3a(lMZtWTjrnXIm9BQ=qxt|So`0a z4D>l7yjH`6f{ywjSXp%%;s+*(>A0u&v6?EhvfyqMB9+KM-yRNUn*gp{3x~4D)YC~Gr_YjrK;OguGO&upVCcKQo zLGmT|*Lhq}=PRn4GlbtO8Hn`FzoUkZ3bT*RiqI(;^j?$+0u{m!S;`81_R{bSN`SVv16sPEX8LPuqK~vunnl|a!P;nH} z88Xl@#IN51O9MI{%@_VzAE0cUUoBh_Kcn&4uMz1zz2>aK>gqFcGSH`3=dlhJHUTiR z3!ER8KC@&X6YrV7n>((=YMGkjDaTGBYf~ha(eX=WZZL5PhpnwWbc~3&R*tPhJX=qW zx*b2Gw*M<^jokokJsW77u7lU+w;+DzZ*#mKZr#nxlfuQ!npI^q3vEDB>rZHX?krNa z7Qw;H3nu*MQ3G=;KN!j@9O{{P!c;CB6AL%^<{U!B=+Y#@Dhje-OlvmHYmXijJ!ni! 
z-9vT%eyp_RVi<;Ys}NLn1I>GGqoQ>RKBk@vdN!&GY+~*SJyU0{8n1;P%<0)IJ>ame z87V#gp+p9%+x-({4LcF+8^DNQ2DWtU3XKChXlMq*Gx-%1Pg0dL{{8-l`jKnbVBx$17FI;e4S-X`1ac>Se(2-0eoHbC#Rba?(qU*s zM}@h=KJiP)CNCo`Dg@d_&M>qJ$BG&{B23M19{~?f2BJj&acqrEgQYH21+{_8=I4-4 z$KER58 raT}=#n_yu{BoG_uSZ+at>>O%#EV0i2weg1|195A2Y2G&I8&clJ8%{}I zqjusvwnYR&%g`A{w$WH6yQyYvSpXSm_z1QoCd%V#>*1ZaA2o}NOENssWS}{G>X}%^ zf3Pa#4ZR;Cero|Nttgh^0DbFFcx9g8Vi_~n{SC=L55mrWqONTa-Y)(yFtUM$X(^Jr zE1-^aex^w1I9SFJA^XbIt}b4XHz7%TH*1_AFA8is4}9 zy&yY5#RjtGY#>p~XzJOcG{NhKK80-R zXO*_KyZ#-bifTA2kb#ohxy6OwmJFG?-RD%uK;kjl0ncxHAGLk2!{31``L6HUfbD}H zV5OrS^i17hhh&P91pM|FkXo=C z@eD7ad`#VYxzFA$oyDCQ>kA#)@j@lG{fAJ@91k_^A0wnhbm5VK8V25lbm|wCw!bZi z=r5L{u1}aiUb^R)`iaR!$^3XeiGPR0(uFAk(*vo<#zScThS)J2&`kIn?d`59B1TS0V$I)49&18;OuI z^a=tUykJO%yPyzjp1)UWC{yn54myoq(tr8iilI>LB_d!nBT~stb%Oe9_ zLlct|5fjcKJGTpG(2!k;*!60Y^|Lg47sjMI^2bf}t5D7o1?d z@eNe++wd`V8cFLzp-Y4ob8lFOzs9O1vhSvTL{0w*#43@2K1UM;wfX%hf&v*zWS}rc z1}bKbx0?2k07`I+MfGrX^_U?8ZD~3OsoM8Il?5*F$UyI*YJ`G2^8R+-LvpwmjLc}) zt^?fj&Y@PVElRitNozkGX2?Lvi0dIcNU;qCFfO5lM+Ta;Clx?JqlvRhWFWhcCZvxL z36+0sN<^PSW@b5Tj7_0o8ikO|mynVtfs?5TmKkkESnj8&-bYTg;_rS&zVq)Dr5qsd!{!STtTs@1rBnX+`_8}2|aVoUv5pTYIqi5K4vpypy;|=GwTk`yD~*G zkY_2PB}Cwz>;C@EY|fQ|t6*SAiHns8E&3c46y%*!rzbz0M^ElmAOnRXyiuK=TrBRQ zdiWHQw@)EHZx0g7_oHf3oiD`AFAAc?t#;v#U^7_-*Ql3JIX$jur{KV+N@O6fO)sK! 
zpZd98xBH(ct!RR`0vTv??FBTdl`B~v z5b<%_U~OWK<;Gcvsr`0|-}8J>_7&S&#caJ$ZAYLrN0kgjez9uXhX?Z=A3z48>yUn} zsU1RqkJk(tC{lbJQkL-G_et~ostOsXXn;$oD7XC)4Wl1orMEwfXqT;RInw$*#(Ezg z7@9gj$GR5T<9BE0vl8w>$|D1%Gtz+k`=zocr$NU>P{mFKGLU=ZvnY8Cb`ZnPzoD5& z2HIHvEvgltk%5vUSCLuzJQDJEVRPevshHG`jFOH2fwT4sQ!mwUoFd?yO`{`;Krtljlbh-E_v^l^n2=|uZ%L{uGz zjLluDgmDU6Hg1I_^AAnou=!==_w#KyMf{AWsRLN)u>wX!ve4Urb+R)MlMlA@64LT^ zA|`Vb(OL6h4AJXSX2?JWHruf=dyIV-lRb%T{j^Cz@w&HBB9*{i%Nn{St}t>*#5(au zNZ;fST|GPKSop#!VjskllzidGMW{meMdTGrVQWD4j2|3U_9A=ouV_&FoIjy)=LsYy zZil%pt?AaV4%DZ6!b6(<~jE`l)iIYCG0;@)krb{czH=V->HA?Kfg$^FPtB z^De{#hmk4*S?~q?WZI)YlCYqT|Kc}%Lq?BwGL}* zzJX|mI$bJ%#$EqFX-+aMtf_i}J+#Q@7`zGT95Puven9QSDP*Q*z(~^)IwsCAbk9I^ z*%`F#{fUt>Rev)F7$M>WQj$|)p=k+iV>ei@ZN`?WDTIf*LtEDp%Wb0&lJf~9RGCI) zoX-uCN42<#M+Q2}h=lyMOOJbycE1E4TYHv!F!X4`_WpwiwX=aPC5Jr+QM_;IW78CO ztgPRM6*FX@OspwCgO=TMJoApfqha73gs+W&fj$u}Lg1giA0?BDu=+KA@-8Hu&%wu*j$QJEiBBuG_w7fhtrc`kd|>W7g3@RCaRdTw zoWF=d9vP@$+P-0NwR6ZRsD_O^u4WTbhm^j*&FPmv5+px^>2Y$-8XMOk?FqoyqdgnmNqmJqg6!N6n{g145y%7}K= zxWgvp5E@uAg8#dFkd0qKS#>?E^=zPPk}J!+7M;;tN1I@P$LZfHe>WFytB`^Cb9-KS(b!Te z6YjgYx6l1`|Z_v1Hb$Ut-5 zL*3*JwPYZrXP^8DHC>+~X0;PE^z5K*xfbq;`yt)U8Jfy{-bd~3pHNjhiC|YdJhfaG z8qQhRAil&B4bmNdg|u-OLIZuFsc#J}vpB3Q_za?5zpxzyv;ENc`vrA7P9iQc4i=_% z(ABemdH62m_7EY3n?JRCZX+W;4we=U&^54xNoX&!d+xC0l*+iH_dIe6QEPl*L^{*R z9uCQ0pn6PMc7p`!oVgBeud{*z?}R3 zdr10^A~7nI=|)}iFgQiO3eldQc+bi`Z|%<8C>9UE*Tw=``qnV=&O@a11{!wGJ--(Y zlGQ!rq-Md+${xCgR#@g-imi{@2t6@fGXp^(KatF!=Z*l=0n&UgK zAZ2|p44Eu#VYz-Uisi{i)$5baqUr$`QBeOW%El?EGW(rm_yRH$)MD~N#YN}>I zCGb`Ik?pvR@}_4HY;OT=13MVHBw%&bRaSPKyRZ8wuib%QXKQHaSwr730jmnXfpiyL zQ{B6yN;`K@(QyccUFT3eIs1Mna8^`U0XIi0Ry{<+hbjzQM7@gM{~mRFE+Q!+3YMmH zX1oO)6F)@RFv){2TO!!~07-o_T&+D|U`$Tp25hLmiW>fVpnsQ_$;Tr(ha#k@8;>lh zeHhUJj@;1>7QS#vH~0+A?N5nCPFH&5o`;MozVZ8FcFIkZ zSC%oGk}8Q9StVgpH?6gbZSO%gdI81R>Fj)QJ!VgKA$Q-jJ*nJh-S}zb?-YvJTrU}+1rw#!Nz1ToY(F`5fLuvdrB5l)u!kuHrKUvSHdgy1(ZyG&jXPJ zO3EAI?Q8*UJu7H{*5-E5)-{7$@_Q&Anx12n@VJTexM-ML 
zI5Pjybj2hJhVC-mrmjEPuJ2Kqoe2YTS58NoRUvikwnCdbkjUy`M(YWc%{NU$Y~OuG zlvD1D;KL(u}ZYrBxvO-I=% zw`J>gQ5D>0*6ms(HC~XfNi+s8Atg2eR&jN2~ zPupSxytcjrv07V`bpI{n7d61c%94@tO#*8W-*{8)`;zVYPgGSk!pqGb8ah7k+>iqo zk2NsVHG$sBL1cE^WQi~3-;o}g^#N7%1e3-5{v#&~{%`rdqL{rW{(aB%HF|y~{=)p5 z_BQ4htG0c3knh;g|F&8(P(%Nl>^c%q*Rts~>{Egi_d=dqUsP8YY|GMfE7Z|7qM| zcaf8k4+k4NXzN?T$gc(4+ODIPt=s(msQq72T_T2!9c^B; zgO*_q;yW*(dg?x_wp+LB7t~V)$$9ZN)Qc-;d>ljnF_cn;K=vJN0-U}ky^njS>3k0> zJm~l;N9ddT!D;g+knX#O%7zJeSvx@6z!rwqVF)R>2tFeMZ+@)bG708L;EO% zuGnLPXVEM)TgYNKVLtrHMVXbPQn;)Q*(4JM4YU-t&)W82=dIq3dB{XbT-Z zFSw_^i_%dw_9K(U4`|qN9b3XfVPfV8Jp)S^MZSSbI{K5}_K`sHI*qJyG3+$V7!gR% zH4VXKr_s9iCrF5tsq&k-KqRd@%qDlYH)5M&_3RR{s`v~VpSgmptSZ=)%HsZA`09`_GW+R7(2tzbPa;Hl)%!! zfmLF$Px}UqOF0+hknQ{{D(d!Ng_jxgnY3-A5SaZTBvVA5`fvAO)WSsJ{P#0*e(T|Ofg>m<(#o<%*HeAyA)M&@e~9AC(e&eXFEfK%r2Io4P5$4DdVJwc#ADi;{!`LO1=H)U{B8oydI9 zhtY`-0>y_0KUL!k7NW6d-@JO0zHd|{1F>A~`zpCwYR4x75s!iR42t{);ym!lK(et5 z@@G`N{(~TV9{$7U{0L;ACHP3jZp(2)g$zW75%bEsj7 z0)F3jAlvaBHis&Zfwq1G(eN!bZRP#fHN%1!r|euuV&YSBaSg(oeui{hwLM)!I(iwU zqB;dKP&>9$U}L%s$$#Jk(o&LO%oxwkFmp~rc=btCjNV5bIf9IB$|+yj#0^w5y^4sn zNpK3tKxEA^ln>pXz5c|HsPCp@!q&5UBu2j9wl`6_^WV9nfoQ=Pzk?c;dtD?bs8ZJk zTN^L7`&~yr4pHrNOc*T$oP#;jcKi&I*N1RUHti4%XuP6X1~FQ>iW!9J?n6k4 z+YB8;2k4r4!))zy5baPHtK+v(Gjb7S4euhe_G9FBoMnbA_x&A6CT>DB_$_iI?;^AQ z3l#QWVxYoq z$vL!a9AJ~CpX}Zye?fJ9JKP*Rp=aa(Z5>Z| zZhaoPT~`?)Xy)AF9X~@dd;)pfQdsbtNRj$P21@!2;t^%p6{ISVf&7M0G<_{T#&1E? z`!3eFD3E~)&Z2<#nyEP4AIeGLTsqoT6Sv6%h#deX%)p6GftaWipWD8cTGkTzfH? z{3~jjN8sn-2VFy|q~rnDm`P-|U*>kT$#oYwJCc!8C@iamleIVWjhvuwBqv|c?wgr; z{sSt-b#SnDVgw3ZJ%9LT9Yk*b4SB+gt3Jj0B%RpzIWjWq;A9_y;M_Nn+k2PIYkKyY z!E?xnS0Dqe+KX~>9{JBeYtrORlr`={u$=-KsPZxzrsXv810=m)AYoH13=OQHrRNMw zpJGJIzC+bG?F63Q4MuAcId)f2R`(oM2gks5RS7oLoy%4zZ$wa=1AtBo!OZ zvjDa7eNocv2DU{O1Qzfn0 zb(kgEChwrC{UG8u<-^e}5^KsoMe(3~9VUmi{|h9C4EohAc?!7=$0 zmONFxwgSe#NA18@NJ>ax<)L--eBr!)1bI9E$({X9?__$NdPGnaf~radGLT&v5*of! 
zUMFX87qvUjB4M>68R#G?25-sd9L=-y*c?JtLr5l$(DtiCX8UWn^<*bAMQczl`dga-V|ijlmO0O^JiP zjtz7Tonh*jf(X$Gh)3n?wS4~8PTfLP?G#oA1!I{W%_|}U6(U+XyYA4uE3NE;pDR@c zbbyXd2wcN^k=k*djc=}E1!vz&$8MlhG6-L{Am|&|L0j7ij$2+r!N6T+kIm}Ijf>Nh zhI8r3MPmFm#C@M2CddPt1`g0Q_ksD^Q55yuf^??(6A{Td{aMxgHlnC%DUp3F*TH@3 zyO0oJiTa{3*!vMSE0KYoM=@1i;9r-?^=~LCYk;>iktm4_REaq08Ad?lw`KEn*JW(m zv}J}2D-F$!)S>i_D)Co2lcO$pwj(iMJfAaCz^*NFvBd}cG0lFq0 zF!OCh_6V)5{O5QS+(X^2?~s;O3|l&~mXU!vk=pfF)bZvvTek}R zZy{}SJghBfe%NBUOD2gqtZ0WsfhoVD%>SU^ zW=-dZh+W5#fpp9_z-!BXh}7Der2B6muaHLus=q#KyDFe&;zu<0zm8R*>!C}DGy_|9 zyq~s?GrZFdplJ9fUi^Uk$6F9}e2$#j_mJ6e6qTblmHbEP)MXSle9ZjEtoC!t`cqMk zO={$*;wy@keMKeQfo$X~j||khVBc;&FyFE7Ej0n)0yOl!!LCDb9kFy^mJDQ_jg3`D zQA4o|l|E^nUqVXsCXT2^WT43BP|$l5GBtJ%t$PjaBXG3wW<)hEu0zBiU#zUY2JxJF z@)l~E`xMAPRY;@c#LV~8_GF&2Jz264dooJdo}~UL_&)F@lH%j#c90*O*N;9HJBWcu zTCEfC_6%b4TT5p>R*Q~9Oh>m&$90%IDY1AJnQ<{NVHLKi@|_6DvQwy*ueCEi%#NQ> zGk6-?lhe5h11ttKj#B=$$8MmitP}?1vl_d>*d`B2qkosjuei-{kN%uAhwh@Hb|*sJ z9HF6S3r(AFcx`oj!gW<#oHUB5(~1^oZo|3itTukv~YuA9A}2c?rZip+JYecnhWNfa zsGsDGD~g5mA3|b6Jd6yjp{?f(bN5_C*PcN&Z9tN*YgGMFGIRu6v6{tZHXM<6L_FQE?ZgWtFfor{g+Zpl`SW!FivcY>bGP z@*^TxCABNaujqo8wH2#iu5Ym#9(gCxH1&|rMmAvI=SYl;m6L%s!Z-OK#42mg9kzM2 zx^@Rby#26T&lXxHp72O`4uxc&(?+q*SFk253AUv&W3vCm3IEMFO zd;dRJ9822sA!0U$!_deYT3UW^j2J`q$UW38P=Au2D5`3N+cX)-t^}Lu2t~zM_hZFP z)JPj(YCvJjS8CP6*O?779r;}3r4 z({Zn*K@DvTX?ma8mD^iRp;{>}PPXM)9+B~<+fZDx z?`v$0kCVqGU0~r+fJn(HRE?>LOHN!tdEE71I6Y4r+0@Eok)R6tk~)~vF;9#Pv=wnZ_tnHAX zHVElPZu=ETzy}U1 zy0Cp}+IA#@_ri4FLxU0RW@8f4iR&n7J%XgnVmMm4Lf?Q$fPS!vdl_Z4F;ShoCw@eg zq!G^cly`7~j;=3!()OchSng}e*JIjHS0bB4fP)>gg=mwPecIQ|ewpcu*+3_ep=bkn zWGdJ|myxqQ9cI%ukYgz}_uZbu20DVoEebY}TP~txXL&ZzkB|(L4V0l^1Bnq=zgQb+ zY}N*{D@8)XW3quJuA{PUG@=x?_n9X`#-K%)Lm z7@h>ozfp$CaUkIRSu#*l=OC)bepKo2Q9%81hR^H-ukHVcCDr(C$-t~A6~V#L9Xf{A z9CNbeEUJn8qS7}RxTg>o8px_#P{6}>%ZI2Mx~ZnE9OIv2tvwkyRN{}|3Io?pmH-u`x6l`Z`p0+yVM_@d(P?hjh_? 
zWY>O)y!ua&U3~ys^STfm;14}b6Lt==&Dt(x_WTSP1*fRLv3sZ;JB)<5Z7?Ucvw=A* z0?H6x{5~?J2aqB?gyP<-oKjS|-$j6oxPjLY?CcE#V<+euS;8@9FSb@6M5g2*lIy-j z^~khL??Bc#0VjJq#>mw(vWB^TBVx;6MW*CknF_|#dv4$?~W;PFp zci~~qS(9H^GI|+hvJnKjdceqb1svA5AhYidq$G#2yQmrZ2HUo8hn1NJj9emEP^NsC z3?w3Goknh9Ej%3EU|_cn-dlDcZ}2YnTi)|kkNtp#9))u#$>8pL2i0TWtF)b!$vxB# ze~$R5c-Y!{vSjQQn$?m*gq01=_9 zU}5VE6R%8cYCNML!?{32KV>oylf%sS*HG286Coa6Ft!bW%}Oa!DPhB&`3`D^ze8qG z5v*10@h_dkMx?4k=^tOa_ZkfM&&MS2v2~isR=Z7onYWvjE&+GkdWcX`nnC- zu8YXZD2Jno4J&V|VHN?`@L|MP9zZ&g%Or0it@;_nL?$w$Rzu&OG39mjEnv2O2exWHb`w=S2eI0P90Et^Td#m`-kI5XK6V$?J%_P1GL%)q z&@o#9$H0`Z=kTM4?Z>)tZI^>S1vZzTxGjMRIa@!+<|P*S)^~yfGw3Q zHMWC>mK$8w)FZy&1!UA5;ABN9Jt!lhg~~H<$)eJRy$G_mV5ON2oa3>o{1WHzOy3t(`e}X%tNneUr)vdG zT}K$Yq+wAr44GVxcGR8+ym)|pjD(A77EVNem)TTXk&eON+8N8K{I$Ll%$<_4A^Rnyi{EGCB)$3|5_4PN@9hI!Z7Y^Ars`bA zK{6yXTvsEbM*}6*vbS!4xupX;)=D#M;sLh6O_qqOs%n9=qZhP|i0EqpGv9i|u(3(& zHI2;!NGch?N-tk5r#Wh54_%jJgwwoZvXEo!4n#di5gkfL_&GwyDhv)uAG5^0>a~@B zC!@2yb1S*>(6nRUeG9^J-$#b*AX25DA-D4q3!1C^_8~#`V`W7xT%6sx-Hdjnh_63| zYPIiy$mC}bzj`IxF-t*V>&*vHK5%np&QXv*FE1k-E`1rvH3yMd z^D>glUq;?nVt|A0g_-UHrXV5N>e{bre zZPA;o>||J21T2mD>tGMAgin+zCwQvSt{W4b`uhTU|N4TSlkrL~R}Fp+lMyImS7aCH zAGvl6t;0S90BUv95|E}2X`wwadUvX`O7&$vqmAbJHH+k#!Sd^O-svSP+yx8aDEE}U z*U1YBSrtn6MwU#rA|)j0b@=Gii7FwU3wb8NtrI&KR*-LZlTMW_CNnkqUPq++^3M_5 zXQWlSmz2yuD11>L4u2bP@kHe-A)B3ag=L0vp01M@=Nr({(rBv~Eq_#SbMrhc$&Ue? 
zYTZ#w+|t*5zu?ZxQh&oxL-jKGbL>rIqc0h?Aa$Tvi5!_ORb(i}^rSFXP@?E+E7Fl0lWZEXmyf?hD=>+~{Mxkk*cJXoNnd9Wi`9_3g_ay_ z+PzJ><4{;~=UPJ-!{Ou=zA?5fd-UF!fn}5Fxz=M)6~$ zPG650p*`fAl-eYWg{)D3Q~6@8z;}dvwtLd2xC99qo0IB6y;0q0X}oV2K-x}clK>N> z35Pe+K1p*=Nq zc~XHq8~f-|fJWn8|9AlWL}h>psb0kGg2HvJ7vk)CbB!ni+`eL>K(upLPdi0dHG`Eo z=&lPH4^a>rFt+`|7YF~8c-QI7L<#~bAe3JO=9tKphy^I~UQ@)ow7VLGywzUh4YN6SK`OQP$h~vguuMXrnK( zj)Fo_lh4VNklh*Nk{|m$BbgO-Mb51RMzuRjdN)8WJ27dax#5wNuCU!)qet(`@O-vk zyviKc876LjZDqfh1?^Jg%j_@s^JntAfKlsJyoWJN6|N*o7NPisAM7b~t4R%gqLsHL zb^#XpIueUkA`2S%5`Q@8l;_t@{27?jUe?$+F3H401fSynimQ6yJ2dVJjzDPqV?P(U z#;)KjUjhd@zfm&vqz|GsOa_>e)xz!-i-F|f;3B$4#3~(p@a+Rcba=t6hBBG0)}|n0 zkpr2L${-5j)B|k%w4ua&x^J4R?&u%Y9*Vz=d_CS~#r&w*U`uK`ypYntb2^T-aru|D z>8TjR<%+=wSnTWR!4L%R{wz0LuBMTa&v_9ANlz-svOq`a@0>jGq{Pe9Ts36N;8Xu4 zYe0L_%KCJThdN8fiJru`8I6A`hImQCeX`g+Yzc^h3dlu5ucyTVV~aN%Kajy}L#b~X z?r`PD!1ks0&pI?!FvEMr2^IK--0BK~+|GV{;Ky@MDz0;gLhyB_uyIv@(7w+0ED&)O zAPO>?dl@0BG9T6tG^n3ZKIXgY%~WA zz2P#f^8?2r_Ej*FwiPzg_!H7zLq?ka@Ogpf?nj)R9TO*07)R>$oJtv!A^Q;r!EoVs zAuZnxDAi+6c%UP5XWJEa`~E=ZwQm53lp(Pvd)y&<;v8=PcCv(bVC^v7#zJ_>l)%Fr z83h)9T^dUGK!{1S64~nt0&x9s!@AY4+d&Xjhc?EZN}`a&__(hX)g_-7^<_H@xUO)= zu5V0jcMuF$26hah8qO6;wVIuh9Yq7tRNwPSNBahl6B&&~XMi}EU2&SgT?*GT52X&1 zW0CM8fgPlV-$!WIvE?Sp;j*wnOYj(!Uk_etJ$~NtR-)cUe|#;Si*1n#N<^0YbE%;ZGmi5bh^37J06tCbM04(G9G@Jb@H* z$>$v*m5-AgyiHMTDQb-Yn`zIaX@=4}2P0oPLU*q@A6Iq7WuE3eSQp?sx@SCo_)PoMktlnRE*vXIX26pCjcQ{3f* z&_FgZEpA_U{5xDEsHy@@v^>?~qx^NyL}F}!Z&5sdA5@u-Hs5J&)Hw=i{5#EXVE3}2 zxrL%4rj!SQ8X5jQICF@GQFn*%V`{z$bIn}p{=j@%%*X6I*b@6ccID5LBj^GQt&=pc z*w`lfevtF(V%OXeMCjc9>=cP+;ex~QB27Fr+)JN=Lu_;F_g`}FH)_;u8rXo5F}|Bs zoUn$X5*&;~r^=45A0RrMoX^GC@RpAO8{zKPwA+%fCKhAxwvNpmN#(sKb}Py|z@|^( z;!(sKaMn7Xf?@>qXf;+Kiv>IS@=TUBMYB`VMl~+ZQowXaO*6vD5=n6*NlMF3enTQ_IFj;uC_Y$Z)N|6Nc(6%~ug+34HlR zwW#S~i&m>iK}oNtc>F7^pz$6AJzuMt8K7xgVwy7#>8}8jA}I4#aDRrYCZcK2&bumr zfUdne%AyUEs~ka@`jY@g&_Yv4RS3R2i*^sq4PgE=MDZ{-a}yJd(WXu4cp6Hs&>%DlHSwb4KYsJ(bH`R8Ya5$=n@^q&IOF9J4|rP2KOU 
zXb}tz0}jK{%>7>QtyEMX3^Y#giW9$df>jjM5vR(cRDrU{1Ae8yD#N?7P@Cfnrr{#D zs*A5fc2dRNa%2#Sq~Zd+a8nDR!@umnHxGwpokh;OQ1daWr^HR6Qip$VC!LKoxe0`* zL@Ulcf{ddlY!)KlXc&Q{K}9#3caw#Vgb4+q>enCOMi0r4pD~~+!8F%$+4@hWCDM!R z?77ym{i9MT^V_L4ivn!kYqBmb5W;q}=Cw)eN4w#wQmD^gSZMOfp20Z)hFL_be`zFl z^2$=LX$Gi|o)e6;)Gj2kF^@IgFi~K-iDEu%0LBZ5gi?Jc5(D}kF-wn@k`hjeO~oZh zxVUVmvRZq3fJpnE4Kxc_pY2H`hMj0m4MR=rB+N%itxkT{bncK?FO9wI?EYiYs!j|Z z0YQL2e~r!UEs(qxYIPY%X^g>piGE|(CN0QE z1H5=z_kF`dx$r({osgQ^eXe8r`J;b>AvB2O;*tYaI>hS9x$c?p8TMH{BeS^so*^^y zy!~?|(ld?gNPth6L@GXx)IX~-#pN#WWu@_7Iuce4Woms41y4^!bXxV4$Wl9z6&;JW zb^)+5P~6nT-U_(nGB`p;@Z)u2QJ|S#6`!hz3&~N=qWJg>d7D*NjX@)-(-T+nf)bZdn$!EH>u__ru zwRAj3k1)&>nM=Bo>cdFm4e_-TON93@4KJ(dnufnQ(T0wX6|mB9CojBCg-R+>&E8bg z0gfW-HcOBF^K3{fC&pIofj?Jgt9FmK9dB(V^=rb`Jq6@6#yO8nu=Mf>`gfPq1nWvNw^0_)+W2j&BCt)>DM7~q?M zk9tjA9PZJ<9({hfEYUpjGF*=Ap~R?Ft^M?+7-er_l0z46`%g#SV|hq7Uf~GXqR?mm zL_OB#NGjiHa|h%2J|F4~q!Bj!sGF>yK1@2!gKlcmgmZ7?>f)*gLnfkduIoW&>d=nO zroS%f^N*K}p=kDCTSdeZQy@!!(`09ita$WPK|OO7v3K+nNg|Xeyg&`ecK=4^&brty zv`zPgY2_c)%!Uiy4*?>f4#ga4{HXq8|EA%vs(!CEFyxYGcW1R=i$f*jQuCgdd? 
z?Il`78!LGE>B|?dp2hSkIhb}69U+ANff3RmT`Nq(c9$ji%8-NzR`&+mgI}x_D@BeH z_?x1jWvn35@^|5i(|G zIET1cor(OUdVhQ2PMopU8VsY6G~{SZwVSP@h63952 zA&QMGCQOjQr6=^gXv;X+BzqT@e*|u%ZUXJRa>2(O#JS=HPZ}P6~d0FD`)x-fK9< zmDdxCd+lNDdf~c$-6jyDVx2{l{sAuX3eL!2TdPwpOa*dbBY#O;1bT*I&>`MVe%Tx8 zdB7iZaJan4>Q|DDQid{6G+uKCL;Q!I*ZH7)n9iZDpB}mzbBJQjn9@r#HKSngAQb_e z_KTBet_(zjg=JDe#IE*=c-KqJBrofcSn7zAqt?ajo+`yt$?O5dQ~Ki!t*W@K77=4; ztu`NlAYJx5S$&5m_r1kT*r^<*r5~4UxJ#D^V{_P;y_b2!8pdp1Q<%dD#u(V;ok}Y~ z`Ob(iEGPNAC8T3sQAd9kTMJKyvmc`HIa8;7hyA^C8{fR``GlXNV@ZbmU42j-;hzlAi=$-WboA)C_Q|0Z}_9toZ5Jy<#FEPba ztv-(9U*F@5e=phn`{0@aFtt2_b{cil@mCFnWbuZJy>uob4i2OR6}D*TdhLfuGEx3u zO|pngoScsL=#GV%7d9NepS`^qOc`T~m`Z#^ZQ!0>G4hN=a3YVXue8Dsz(!^72 zg}i8#BamHQ!P+bpq3ex7h4`eI(rGmb3)7rU*J>N1qJWG+)gXwmm)9SAknXO=WWWeO zq!UHO@E|mQx;vV4T3id23r@qD-6+xG89D6}_acEa9)_rHNhXH?fSG`!QDJP%k6xDq zXuQXPA?Y~2VM z(causY~{z5PWpquKs%me>5I+b3ml-I#?(a7l?s%J&_rFfM^d3B8lW^fzoANh+!wp4 z$zh5Ad!n$=HvA{+n&}BXy!Ag79imK-dX6kaYli18G(j7Yi*cbCnBOIvEH>u_m(qnW zQ@}(tl*qy16=f((of}}5ne-&?Ez>JXi7Sx>y_WT$Xe?e-QCUFd69U6?hWLLI(# zNueUVo`iR>QMGh|fOy5y@)Bf4VFhgF_P&t-riT&O3Jq1&X{HfKpc(eI-RB}bs>I6{ zRS|zO-JG{pB}}jx?@h4b@z!8MH+*n;-4{7e>iM+P6^rhu-YIUJg$$t157X$Mqfql{ zsKBqN6|0$&SPE$E;=~NLe{~x<>q=NT^2>R|Zfq2}yuCF217=Jt-N8cl`!?zF!|S== z;hYR(_42FAtlM3PYaG1!BNMpw$ibW!8_DLT)}!wyDYDlo5ZpP8y;t2oIEY?Z=!_j) zV<)KmOi=PRoSa7#CuXnOh|g@J|0`Vko-+cvkTZ`lLGH&5Lhv0RB_^!Si$l9T3fV`edKGDj1=1(FC_WkV0=IVP%MsaW%+n}% zbUnmm4zpf*IP=Uh+>S8x{JKC!rgRg_oEw2PsHdQ+*b5oguN4esKWR=DdPWSW*laf$ zlN73=%4&BU3}*e|PZuz$8=r74o-KCPBu$O*a~;nLMlbF?;E5aA)t1;B)7@HGHiF%#y^3#^_uKCJzDcWdeEsWM?o}0FrclVS!FoP%YX8@uW zw|TsauCA|dbCgB~9&X9u#`&7!24~Rsi5NS&qVD_3Fq$-!6>&+~r|lf^9e!}Ovarl? 
zVvjh$+{8}7RZcHybl)qgRUc(REAQZ`v03HsJ%*&PB0pec2^~tuVbH7%VXIM+)*`Tg zUGD2)LJJZ2x!)9G16!w%oQhu3?@)ksm~-h^;A-J19E%PPKePpuK3*-0N37Q7mRkW- zJN}SZ<%n)F&DI+93pk)G956`$R_Z3Pdeb#bpXh1BHVP?u$%BeIDhqktC*Ghk4%pEt z9dqnRy)a2&RQ%>pDWhU9E#w)a+y%arLX>#q;tjwdh>;n{jT+!#QBXLy=CAfYWNNDZ zlzvV8KJmL5)`*L0=yoW)VVZHB9cF5;`r((fkXz*s^fKwD-e{M8pF_~*IgzeL0{%1h zZm0OkM`Ch&KZ^!??O5@0_H8!(Ej|3#oLP`vP5pZnIbfqCz>o~Ai1bT>k zGrP(7sI%H4Ih{-Lv|{EJ0?XOL&-g7qc^DJ&`~K&#GJ zA?kb;+xNr>yw{^i*D6atMZ;GC(TlpmV{Ex#jEbwwisGq8 zkodAmwPNH;O9Toi-ZGY&!5SXG!4XnOoGmG7B&|tkuKGpm?I9~Ep@=T`ds?EW3h`&O zR$egEI?9EFBw{O+9=)X3sHL4YcX-L}RCnoLk&bM__)3=ANVnu}Gz{vl-{w!$D#+6+8Pdz%yx#HDI zm8gyP6LIuDdUwF&%pNfpRtimp?F6&9Fodk`w*nK?4tx*fot~O>Wm4Du^TbxP@To6l zzoiw{ow2q=%59BcMH;5YB`$otTyH>&EXw^LU>=PzF~B#`zBKC*Io;#Jr2iRjY3WH~ zI@m3ir*U%m=}u6Ks>U-(M&+JRR!-R*ro@uunL|$`@)QUuX$|gk4>DVEoV`LZTiENF z>?IuHNxEZQ>X@JL4DGrjS$g?N!}8|##cH5GnCb6fX0PBoY&60b%dcNsLd?fbmL-Sl z0r0VlY6By@8tz>Y9Da}$mIB9P#Okw8)(2GZ975gVeF(kOb`3%H#G28qq4Ob8>qVGbv*;gk76$3n>bl3klUm)qqes6eQ?b$$jWK@ncb|f0&Nh(ZnB1%1VF)(@q@bfNA%@sur*0g)q#$W^^>+Mq$3t8hGwR@(x z5t*AhB8u494!HxIQ)9lRbc;|ZYehmOyXfwgm>Q{jG7CF**7p5Nk@+P5$fXvu7N79m zfz-sv0-P>_nVH(~t{93+5u8n8-r%#&Zzg$Ri0*ebZ@|#AjA0~67xVYQs_J_%Run^-9AHo{Y!VtdmP!y=lR z%BGH&K>k*Q(cvHPd_YDl({p6hW^eC^CaSJyE@y`^P6)1Mo-xhr#)t6+)&kRQ%$&yu zoo%DNY~dNs@VXH5ku@V&eLrU+5?%4M1!-1pmgec6@>JZ?AX>l3eLk7GN(Suy-&6VU zU-RoGK)I)I19rnqbe7$NKd^#`h_GB&k@h*Aiy9uX`*Mq~6k&Sbbh_Edrktoac)SkW zX%fW9?1{1w`xmSMuT6?n)h{0;6{Bh8+naMYGw$+I^9IFxAA3<5st?196*v?XWeK!q zs5_n8$g6CPYr=WzX(AJ8xVZ_X#}jG{#>q3fcx8{?VAKM?aYMQ$bFY+Aatb=vkjC8) zqib9pe*Sf*DL>g^k=0(4Je$&IC|wRqKbhS;B>jMkit-`FaCj1cf|4!h6rA>1V8K}| z{@9ewXY5y;dsEY~pDxik)AAhe&OvgT5$Z(`Eq@!Fmg9hLQ&0y~kAbdLZJfNIr9 z{A%WNGSl+c>n8;Ob>wt$M+<%kc}h)f(HO?u%_50jq$r15Ob8)aTt+*hDs$)M!B6yp zgN%ge<2Xf3HDhRBqiV2#VNB?Df6m)^T^qQ{Ut|U`y&ynR2Qk}V!( z6qG@`IFp|tG&)ZBV=Bc(S2vhgNw)ky- z6Wdo3;smeFB5|pO6xPEueis^L@|y3c@T=*=3m?iK z-AEZ@R!v(ckj8oLOo5>X3bwCDcXy!K%zJ>hU<&-FB#YRk zN}2CsK}k?g)CrARh5`S&9Z)Bs*Ks03XjZB&l!sD!E^fG*4z&802R=Y=Iwd#S(Q4}y 
z>;hyzKE>UMR8zB<&^45mS62c2)i)%5e;CB{ZcBe1UI@6nxeZBVb3}6qlR7oRo8`nP z*%&&f5o3nW_yKevgL>|Uj$`LassE)Q!F#~E(B^#EwITWN;&^enV9wjMT?3`-xAcDO z7_%rK;C-Lm7KfCb_!zPhmsTBU-i!Pi*DO-@ZWrJRJyN$~T!JEjghciKwx=lP4K7@| zWRhf?wf=xnT)2>SI?!xo9dQ2KggnR|xoL7fkYHG>Qi*G5VnGQ#?^_UDL42eP#P+8j zVQ)|zXvD%coQCp+TV4F>Boagf91-s*KvtfOBti_h6<{+$O?+L9U#pU2Ar6m&vYxZ`B;)o)3|Bdjfcm0fRv|svyLUpACe-|8*2NZcINvU* zeIW7cE47L30hB?8XfqQq*H-C>NH#~Q(>mPa=*VkUfl+K1ZsskV6Y`hu7sW=g(=G=W z#2y<92i+cEOoR<%=leF(6r3M}SxL&*M_2k05Un$iS~K??h|Mte^<@n>b9E@$HLKH& z;JQI5s|Sb_v&V5=An?6M)?sfA*NZh*Y&JjPAEf+D&FBe_#$Z%xi3!_I3T`$ESFHU^ zkk$gfJmT?X?B3+hip`QxM9<6&O!f*bx0?;etT>WGKpl459q(@AhS{jVdy5g)(0OrJ z6-5Nsm1i@g)`=z2Vfwo_`~*-_bb z+MGiko}CgVsUaj5w=*~igqCv143s}N;1_sPx>eEe;RN1^uZ>1+#T z!G2GW*sPFK-9ynMhW*gL*s&EDxYl+K$dh`|eX#i&`L3GgcGi=8;9q0axl15W>YyLG zquv~NqU*Mkzy2+e`}UVs?|Rbd2)RXrvCQOU!ZPP`^gRt3GDs7FseMB_lZ-CL_7%rw z4(%u40RG9=H?Stj8#`SH>MIJQ(DL)!{yZ)xjlqGLVrqY)G_M9vBt@L?7GA!8P(wch zE13Pm`&I4>FPv zngkJgX9IBkagDd5#U0kC$ExcT5wsEgq=Wa$?UFx#-+#NVHNXtIK9LAKu#`PQqUL>Q z=NLIyn8DYG(KC?>8Qjok3-KY*y+P?wCYGQHxY z%}ye6#&QhQe1l=uX#Oo4GYV_#Y~&W!v7_eibIkF|&*(yhQam}+iD)cnM#x5JO9lm1 z7I&g}$-DXpYiV#Rh9=Ejn4vg3M8jXq#~_vI%qu9kwB=)qZ`J;Gsx(R~7i_FTWZ#0D zkQQx6Vjo*#aA7ucd))K+12%pRJhMIFA+ZjmmKfcQ<-t+LZU=t8&K&7!$GT8uD*d+s z)ZG@|90WHpx%Zqn7`kSH_e4J(*b$lin|nn0Yp)O`HpUd*xJzD)Wimy%rFd<;OH1X5 zw9SgbSXVKQlZ46gx91-!UL1!vVR&AHN zt$&S|;JkvQrfmD;@euw3Q!DzH4o|xGO1VE|js@0yU(wxBi@MTe*YSY%u**&|8siZH zjsUtDzWddb?pl+fN(-n-5iNM%+M^wVk>UTEiOd zpOa9+ew_}?K|H3_(2tgX!fJd`5=p>MC(AwatT3GD2BhDuE?Ht@2Uv!RB+|D?&9{ub zZIrupuvy-h9t!`1MOZxd2~d9SD)jpXB`P*O1fysN-&(5=W352Von03jCjJOgfnn#! 
znp{{8cxb2wMUS5VonPIgGw@gbVJcIGrSrYkT|Q!o>3?f@*cV8u$1dI2}Tsc zg{4&aBM~M#@`CB^p%e>!Ljv@Zi&%EyMRZlF-&rB8tH=ti0!YwfZAs7pOr)wmV*yV* zZ9N2_|0&*};w>YEv4?E6a{)}BjJakb?DQjEh@X&ZE1;hVQo!hMNIorjM9(+mv?4GL zLy%eIIJvk~xVdb#)8tnxurMLWNL;y63C@&G`uLK?kj-``%&rOuRznxu$=uImtsrvV z{s!hF5^&h@GnMXIM9Tv|*Nt9{83^dFwb$ciDG-{vA*sA*izl}(Zg$Q_yL;hdGjPxq zeA;wVTJ)0FwSEZQ%j$@G958#V9*x*IHh8Qm?=LSZE>%lm6oA_(t~wxvjDM=}b)&Jr zXcoi_thm`~{(*nu_E2oHNzMH9xmO|GnzSo2w(2t$MHWUi?{}o}KzKT6LdfK<1UPYy z0WXP$pNH%nU>S`YfPqaM7IwQ{UD$9=2X>N;7+$dp{~8nwX-g#0oqeph&=QWS`_L%j z4B}Bl{`}`2%{?K8d1TbFH1kL|>p# zSAOC9e~0GcLN~QvZZu~va79$JGUeTVI>~DX)UYv8RizEp;y!K%BD^$TMABj+I*nJ9SwIMZ#&h54=_D=WZvif`Q zSpIa4fM8tv1G;8(so-S1{Q1iX)*UHG+6uoz%c|en5ns!?Gb)zTV<%+j2p@0UNPHmD z-8;gvSFDOK=1xQKA9XBXLkcp*Hc;`~`i4sMK-YpNSA4JlI!J5vyOZ{Me3|bK9r*%* zt4tdY{0_&gmM=?Zth^gj{dqW zia+uENAS`(o6Ur^PeFu@=Gv;z{q9=EJ*0L1c6xZ3FI2+z6m{8bAjufOdNx?7Jx;1$ zr=1g3)HLfdel=ubzn_1Rew@5@(7VTi!)zQ%|YOC?ttLnn%{Q?(>%q{lk!3 zU5!>^h+K?yhjYl1)I-%wX%Ch>(w7}GJZ%Xe*JZWVu@r5ikmuURm&K|}9I_?oNHqe< zOEN0i^>L z;%dq-?Fo=n9LHhnt(`!^E*L(RVLE^mcxT|EJXVblg__<}LB~_4OwLBaCshFmrsgb} z4nVVKs9bx1V^&WYRN`5wCVqxCs|H5WOY2{W+7FDq77Y|q)}B9a34R?T)5)0`$XQI2 zbrhan76_B6;Y}QpX7gXfqoTY4-Q9rm{h7;_#-rM6rJ$kJ052NLu>+8X^Y-KyrwyfN86XY+i^w*s2fH9T4Q6f?~)19j(-u6`sqP+ z)s0~FNO$`&zci2s|>3wce2(LTODf2G2R%A3;c~V=U9{zyB3ACUP zAANLew5K2aSz<*91Y7pxEMzRHIux5z{Fx*M8naUh{e&;Huiqlej%K$hKDTI%6I*09oP{BIQA)d2%>!2vC?f7IC^2uVK$`J%=> z{JUt}nc*70dWUhUW34yd2{84%z{3xP*lxiQh~Q-al6zB#R)qY(&hBE%jG%$4=tiNWBxBrB zi4nLP6#t(cXlzfmV(|{#lLa?(c9U^_wVY7|HNJb<=5_JO-DcU(&Qh{AB*HY($t$;H z1M!E$&I<24mKN53s4SG|JE+@JR{ZjRfo?l?<78jfK?6fLkH02{#u^AS*@~CY({`2` zNZmHueb5B_1*Q&|nL@V3|BF4fs<(gZOir*Xvv}75gxeYAN+;p%I9i7V;WO)##e32exZ9QJamRc zB@bw`A3%5tMG5Z7Qc32?x7!LY?DcXNgJ7Y0)>CvWXURD`?5(2v`sAnZYT^EVVGzm1-@0<=f9N;f6s zP<0G(a5#6mtMpZ^Lg}G5EF)rDieb<632xtlJT#VJanCM1{ew*EjBcR@&ok@_CMTR1VKpA7KnKy#iXSGeb4P zqGx4W1I1jB&tLm!tp=J!mDgricz%68ExX&44!YSd)Q#I<& z5^Y6Yp0Z~s^TP(Dt!}X!ck8ZMf5$Urgc?QO@GKf?ZNRRIs?@h|b+P%nB@+*33R{4Y 
zi}uL!b#^QYrH3-o8yzg0++XD13kmSbB<_VcNH&zjY z;s)vtuoNr-AHQ-13J{NtPpR-5}3pB*JkWB~jj^u8)t^&NALx5_srpbldIaqK!op z4Az1!$WQ6TIuue}IC~;-*jVj=m;`H;-_KO3ZeDE4pLd1hpvlh+6uE!$@LpKU4||Uj z3_6^o+m@<`juj|r9jLj4j5&X5jQu6Fb;Z@uXYD!#aExh#RS`<9_ zD`!-kZKs8-t8SqrznApU#ad8)`|s^yf2bp~r!0!R{{Lg?@_6fF{MY;#J;V9(W*Lph zUTAs?SI)4!7L=P3QvNJd=<$)G#7-P#MhL;X+Y2@Tdc-hFn-6~|k91;&&XG`)Z$#vj zE&&g>x9xibGG|X`rYsVuYp+Q?CYI~O9VW-#k^-K0hXn1DWU+j zfge*GPM7>gtqdpFrv9JFc9n)LZ^uMWi8X=)Cczu+K?U-?!F zB7XXOo{h&46s(xl>jF?)F{J*79@Sp?Dx|*ly1^hqQJWOBnIE@;nS9!EVf)Hu^*)bD zYM=G7?Al9HDnvB(g7)p6d=*)8L+PIuHqU$XQjsEpFRmttPh*AYEpdo|PUVQC0W}1Q z+^L*lQ%0_>S>~q>+l}pWuAWL%kk4lD=i_^H)34-e)~$1rHnJdM8aqL$=;`0(l>>8I z`lx>J)mvbBDty$Z=DX{FZg`)W*Qs}Y)^u_6$KX=A7(P7|eRn2+O&vC{p!`iqd^&bE z2zvlM_|`;b{4(K3;}c>0PyY5puYj8t;Dv=o^+kvA4?Zym1mz_ldhDnobGyA__*v&B z7I1=Y45~gM9Li#?NXfnZ>+!7NHh04492|ntq#^#6m*kK*Q}t=~`+VX@Sv#2ZnNl83g`X>3y!dILi8S%Zq+)c(ss) zEl*qC1O-ZRXsS`;tlbjj1ya;M)Kb>`&5Nze$v>G#q`Piutk1Y*Ps1kbyrkP+816Zo z){7%)4th`HeuaoiUVTVW+`BOQqHn(}5-+A>#FVL9ojy#zjOWmq>E@uk!?cP&A|hrD zm|*oI*AZ%aZ(vIsPaYbcSEY4SiaIKuxG2$1-uSEuZ%asbr*v}v@l=KcaB377teC)- z3!fq`=ZVsml0YNnr>sk;7UVV_l|wOCXY%&jTP<^VEq8Ch9q5&r=Xw-!g2H#<1R4tj zUEHnYPKyy{UT}<&EedikEWNq54>PFnufCJe;yTkztPtys-~Q?N7GVOvuu!%>JmH3;3Sc+Ei7^dn&%Cp5=mB5>0z>DFID1%`Fg3qQSc?rXD? 
zebOfMddAWG!i`Z!H~G3r%N*{Yiwu9VwCP2}N@$?T)5EDq%Dbtrs?zGw%hv6^p)CH; zbi6;~=E%QE;mr}02c4j?lNf(=%{DJz8EX|v+mN?zTNe4|w=`R-3Zdr5P(EoVu!9LVBHt`@@dZIdkfH@`F?)C|`!Y4W^V@cHw=HM7W<)OHSluZgw(D*!P+!{&l7A{(KfafYpphQ|yXS zK}EE&W0|rxsKSr*xSOlxms(8$k}*P)7Sr4B1NOToMs%w-thcwB1foxp@2AN9Kv!px zQ&%!mDXv%70#INR1uKh@@`I@4N$zXjr1cF)4&9L$M+P$3~ zg5{@x-0QKtbJyn!n^aDsCh`{%-Ix%43iwC)?Uk=4WQqTo*MZU(SQjE((5dK*ZJu5q zqNPwe0(ucn$F{+JZ7+C-@!P^3WgFn=v(Hmz|+oAY6l$ z8#X3HS))5H_?H zRiU|Lf}8pWxeFNu;G!J48?@Fgg zSS|D{_xY(lPYrO9N*p=Y8#=JGtiCbMGOvS77vZ0(+y_q#!J1<5Lgz(oZg`@eqBDFA zQB74M{%%PHB@YkP-0&>$HhH3<;ge4;HP?poExe8K(mGoQ4CgdaDI5PT;CZ@}%=M*t za|uMcX|TT3Qosk`fp;F*H?Kl*nkh_vT}bOJBb57yp-M1OpFF;a3KaZZgr{SLwh`CG z$sE*0b}z<=cqQJbh03NiDk%paO$br!AE*Z!RSF_DnuEWG+FY`r*vV{JZq+mVwX*9Z zIo|@wUS#0H{4v-`3WHWbp%O`i=bhza%pGQ<$d-#16k?k)3oQX}b0;p9#qm?WcjtVf z`T2gcY%~k4P+IEN>(y7gmXB9Ut_})>{kAD>p{JDPMH=FPQ@EAP;IIRgwGLcStvH`4 z%FN};(D7qshZ7w4?F(42=Q|SMW3g!`1Aq%^G-6z05SlgMI){Q z12;b7*uY-Ml^pQ)SgQJlnk4`>zg;t3`$l;8IKz@uP?&>UPN^43rVwHu(H$vg^=I=R z<{v|)N-x#~s@Hd0KKgxuG|rwkYeE|56H&f0q=pD@wDMkTL8^fC5{f8Aid2;j(o3XE4N8-aNCyc$w9tD<@@`Zf zJ?DGQ|1TFm$jqL}p4Dcp*=w)crhXkR%lV?G5w@Sm=?5x1ozjjYchqt;d+;F8yueS$E$aX{!d8oytVJU`r@9`4pf{gT+*0dF|1FisS9Pba(2k$(6 z#`sBR1-^~EcwRp}bTia2Y|f*iS7g(_%9XT44W4m(P6EMjCEaTMny3&|$!JjWafRmj z-85mpw*{z3Jv@6m=1Sfoevg<1=A0CgiL&%_C!#0P*B!-GO!n9ow~zZ6aXq+MgM@Gs zU4}3MuCP&|lg{_8si-dJDAG2S4VLdYEAsHZ`aFCmh&iZ+lx3|nX94jw{1k+bsQl9Y z2(gS^WW97~?I<>=gH1Q1XTI_DW6?eCs(AsSQ1Ftp|LcCX;Str(5|7S}v1q1;#L~yc zhnu0Sp-K;cLuyK@S7}u|&q0@qJ4NY5`ep=xwI)3j$MI}D2XI-tq zJ*Es6@}g5jY|TV25?(=W*e<8|XfGDoi}Jen?CkUe{0Wv$m1OCVX!Ka;?y)T}0p z%|!E=K9M|#_W;r&m-%*O@e&C3Ph<{8HmFNv6)MVOUr$xHboZ9h^TdYJX&-dKiMuJ= zW9!1YLI>#IVInudk4Tf$9-tUQ2ZM`8Y9h5ZPf1@f>R+z7>8SK%t z8H??GR4Ky7RoCQ9^_FMCLbW{?ijF7&ituW#Q}2>IUWNPMSUK7 zuWr5uHH!jho2-Y8;%2_=4w$hpX%>s@$cMW!;^yEhB#a4ILQ$5U^x6pIUi9g{Ig2+? 
z{1tJF%i=F9NZSTpX^&rx=j5K;n1&)r&e!oR`sjy&zrOLC|DIq}K_WujyNUtDmKlXZrx4)4XFRsCNBRrR#q?vaTxwHXs$*ox+D2wD>O;E z(+e=BymM=+UZ1Zt@#!~=jj)Z9xb%wbytM7N7GOQJU8}3zlM#Gya7T*VMTDz$vi3Ft zp%$|TRkY1^T)h5tXjx#brl4=PviZmfn=mtDX|L^Umsf*? zmQ-nMz6}N4*os5em#j(037eRu*53i{9sVMuek+GVvM>%*)@#AIzW;ffT~|r#n<+Ul z`u&CU74slZWYcXV*Mb%E*&FMVuI8$h(jz=(F`f-Hw?pVAEuIAc%QIw!-2H@To2GI<^md69l!!Fsls^;91-V{rai z0!D?wgXt(^)Hg+!e#VT4m_2 zY}jXb%~p1mespLFADf9t37m%V=(mi7L%t=$@Z`;im*{=ZmU|0LY7eYA=o9N@Qd<}= zDH*qhyYF1^h7J|qxu`F+=%K@>e+SJniirh4wAUtlTeKB0>z9(HTKI|RMWAjX*V+Dp zQ_8O0D`3mk^;a;jW2CE>4{e`V>yrQ`c5s2-sI6VT^onHBIVuc z&O@F8%d+szgl6AaMMv%0mTT>vJ-`F%1=;2|eggL=JWtWD*q5e?Pvdk1%~G=};)vyI zd|{U%lu3#%isb&@{%P43-Q8o}9#uTsclExEn#8oG>xT>2k;(^S8 z8xwScEzcudU_}Y_Rblow?1cc@>y|pfC+fa#Myx_K_RH?!=vC!aEkEsnf(kuN63INX>6MUnkT}>z9jKJhAqLOX2RRi$@F%r zIwgOk5ZX67jU3cDZqi^QXb8wnODRdyzfHyOq-Xee5d4GV+df_+C8qM6(y;W zERFYky?s=nc^b%v($?Xu-PzeFi#r>)d>^RZ&w62XtoS^}IEE@o&AswH@~qo=?LntE zJVfPy|MF1fK@agUa_fA`PjCDD%`Pd~8ziI6X07)5l$o;e910{Ju5HNyhwq|6{aZ=5 zF(Jv(KjB&vwUr+9Fb8=IMEV&RG90ZC;>xuKy%}^HQCg5H@@4VZhVL-h&Vuo@H0$;t z#SXEs;_+dc>Wcg+ro>mVxtK_#Ckg!x!gJg8$BZ~iQ?N)w&EJ`_%MZ(rTb}B1$-Qs5 zeZ7h3j_J@DYAzFQNpkHWi&9qUSzv3Bhm7HpH-ziYH~Udu+<%g3X=nRT6rWXqA^gA& z7KCuz<05jFi6~jktm@7PxOr4VZDaRx?oyE{2?e_ZlhlRJ-~;Cvkd}%i&i*to_pv0v zu4#4NX9o&+qSdePi!(AWOhzZL>U?2_Y_qG?f1qEj_HeV3($88WzY13>Xv3eOS1P*e z>njhsX|HVu1#&>WLl$Lm_AxNn*>ZO2m`{X|556h+eBnyUhCr6l*SSdw8Sm4*lGj4z zy&D0t)1{SQKFZ*;J^r&+08FqTW}sRXger?s>6vdDCw^*b>xNs8t}@c_~u z3`26?=O(iz`>BK*A!e!5tD(oc+wJNrg3NO4UBuL`^>s{n@#cz6#m0Y4>xq$IbAo%B9iLg~{GZKU|o*WC;8w5+<^M8)Gw zq46nzyG9$FOedaA&27C}Ye_Sq0=gt7`Me!*t`s{(#VqM<~&eg8Yn$6FtOT3j;leq=^^ z&;>UD=NMq-^lqhApR;Nvn$}O_aE%!PGV-%rFJWWU8u5JZt^dUyJQ=L&Fm(OvsHBp5q={;kIMv}n->DOF*&4gi?R6{VD1Mg=V9w(xcbj~^eR;A#Tsep zIm>Ib#P#ht@zDUUwPiIS?649GpHHC zs_Z;OvhV9|T`#T{2fWMnFz4DW0>jldLKHq(w>?&u;-P}dn(HsEW)f&st-?O_zP?P( zkXB(!sIs-62so^9pr$!1sO|=LaE5{rF@U=J{WI(S7R}_Ma_K(Q>Z<}hJG5;JmNx{W 
z{ww(7BvUQ*eU5UV|3dJ_Ckw$(2fTQ$D7{y?84d4FWRL@1=QIEiqOn?D#BLKBt7d+klUC@^J%zZ05m*l`Q^*10$(9&+a*0?*s;}eBd z8rLvOUF7nlF|(C^yRvUPke0J`$qL0wR99-&qloQPArV}ZN$aaG+e%;-)zU_}3Ow=d zOwZZk74!(7LhnJ^6v(|@j`zf0CXw*83e0lWk}{ zWSC=z0&ZU1?I*L+gJtp(TN+QinUk7l54G((5A58=Jt1 zNXPMp-w=HkcxBNJrXkuZBga2CZe|-@YlIdOJdE77_g)5!5<-qKob*?+VX3U;zmHs z3otff+)=qGZs<`1u7N`QT~?byAll*>sFJQCMUR_Ex4h0TeOKofg?smP0K|^^p!D4} z$S5vOz@Xu0d>mv5y?37IRxqgF#~`y%>f`kZse-|(T&pp}332V@v-;{kU;$9foEbwp zS2lem?N$fKb530#jn81cl>CIO$0ELwnuFyP9F+$`ZtRAQGf=WotEHX^s z{`q%r*6x4LO9nx{A9m8BEdCiRbH+?Zx1z4_E{PSp$1lXx<7dqhm6!RC?Wg7ifo!mq zr}AYW(NtU(OTLe-}S`v@8KqDO7`;AtEm1 ze^SpMQ^u)ijBgdNlnw72)a>QsC&KRWG^EtuYr)Gj@e?`oAa-JYdZwC#q84z2c&JzbkZ1*zR6Ugq@Cn0_hX+qtIzclobG?Vx#?9@m&>mKry66-i5G5PrTgm;|F8 zQ{nZxXr3@;&v_cbehLSN$85mZ2or7qTYa3}Y^7H|VAlN)qsu$b5lf2p<_lO}PwUh3 z_01#+160jf;IeQ}Zf63LCC|xlJO2ldv04)k(Jg}u3m@|ntR(0JkhM~ z0p1L@8f7eJd(Tazecq1;)RmsPBhW5IIfLE{&x!J5*kAcaYk_~MR8qoA>C>% zE~0Xf3~+4M-0`ZvrGuH-QsW=WYCvU~bC>{?>dKH<^nBX!=^XRT{=@X!95JH<<-oz0 z4dr+PUh?!29R5Tx@s+3?bIzR3|GQCsnQwGtyT?s5C&8(+j}s0tq)IDVQ=XOJRJm13 z$h;uDE{}k;;9mhC@YKu#uha{w#iF5`nH~8fr-h8M&=y2 z1A~DB>oI@%N)^N?8~1~t_W;ZUFM!kmoQ-MK095RR=m@ z`spB=86Ur8&k}zZu&5ykd%355nVECk!JEmF@`fjJYhtHWTb2Ic9f;SO+0AS$7$|th zPxNgkPVnKCWBo@k9DAfGYvhh*so#!6RPa+~{q_MpF+@&Oex{3;XykyO;jc(URO?~{ z=tjNpLDP|S59wb4DnERY^<~t&D-fbKcX4)kX$n}%g}hJJd^7f(ZmtUsfQ8_Cz_A-h z1YrTyxDxD6Zmp=65ZQZ#S!^I3#^_Q`ivK)Zd#o=*fke>%lRMSsdrBVr;36R4lIq>> zus>w!i&K0q@EWl+&;He3s@81A>*B}ADy*A=sMPEC^Xz1y!l>=YmHB8IiB`npP$xqt zkwXtNCJvh$`mfHQU%TzB{5qBlTy?Na^?PLSkgscPV^n8)l8fA@vMlJ$kQd!S*Q=Pw zA+Ixhw$Q0USE?B#9?R&`W7}2cR6!ZomC}RxPqrjV`iIQN#|B-dxU4Qsrkwz@$qIY2NEf=C7qv;ocsmb3()m?OC0#Qs7`3|a z-?l}e0Zvr%fKs{tt}azJTHUTurn!dE`g>@D2Z912dMuO)Cx+_oHE=`Bbl)NY?? zVw*o<6EEupo+#bWm^M2fW(TL*Ocno@YA>0+6(%{{GkK?H?^#2Kuk>fWU*@HXx}&2^ zLs`2eIw8F;XpnGcL`ic~$4s`%61_~KWgcFcZ2`QebvTIz^qw7}s+c;X@p31l*AwKK?@mppBE{ry+H0V! 
zGOZlTLk^!MI?x5>E7V^qU$zuf{bmAw?1~5v^F%&5gT#p6yRe&9T*q=bths>wpyAl* zDF9+c<-c$RxL>lN%dBh#9|y1J22Go#jBKw$2&L~cLBMmD{S8`RBLdnsu?)oCv1IbEgOUD;oYuuSk;EFTVu-@}LI>vvz39njMMIP0NxV*ief92IHF;>9e*S=jkR z5S2J~3gG4reYW{vAEezY^V`%sUo7($dz-#ZVG}B)E)<5_LT=2V*3vQxqcE6r#%d>1 zO%4%^XwV^6E#Z$HHa_F3`P$y7w|3M>HA}irQ-{WL%B*ijwft7o(+3_O5H};fsd{e` zYqH_rFJ5ow@6MLUb}Gl#iDud7R%=DG6cKz+aXw&Tu`zM*-5n$=8$e}bb92^_Em(cE zK&z^~7ZDq2_A+S+)^b80{EaWz-?n+5ARAH107q0XO4u2Fi4{U^^o2_Ie!r~fZ7XN3 zoQk)%rYR=))@#B;S-*7-he!Wdv}bc=ZrFe^QP)mV1PNx+q>4l8&Muw>8z|DK5oe|? zy%Dd6H>pSWYnUe7s9NH!nBmzSb?>z=u#L8;b0N7<$Ox5+G&vR6xK;N1nyO)GPffMmey4xc-p%f{Bdf3_dY+0xOP95R(gDRA2jlWXf!QvIl;#JDNcCF*eO%SCQU1(^fskOCGk#QYcD-v_3 zc72?70AlB(Zfq_cgVx@4NA;+;{HSp?H8thQ%I;|ygSm-?%j`s3Nmrym<)%Wiyt?{awkAdM+Cy0V;qYJzdL2bTKv2(aPB33( zLD^gb!<%+?-)+s+Y20YB?$KkkrfgqZN@@+Um@brvaF*38=q8fI-ml%NlrViVAG&Yj zDC2`yyufJHsW!PtCup-jZ>;UbY z%2VXFhb1swjy8tj4jsj-@2;9UY)nI(K}Ju`w~IL7Mm)8#o1GDlQb}0YBxxFzQ(E3# zopHb(4p^3p=Z6mI!;Qk*+BTS2jy`yFPfi9ivB=QyG^RAIueV9M+KJAi5Fwrk2{f$M z#O zLrvTdG|F*taSStGwyV6$aV}o=Q1vymT8>RJ z9LynLK2Ofu86g-F3cG?lS!oQDR&8x>ZUv|V*8}(j6HS6o^x(M#G+Dtw@l% z{(y%?P(X&*U}9Z{O*+%k28c+wis%jL<)dOX^@K2sn7xF#`}^Cu zQ%m>vRu6v+kB^5O!0jNMyuwnOi$$>Tq#Pgj`Bq^qqF7?6jh&sIebH{`EW|^_Gj`2U@zWO!&ZE>*o>0Sf~0CRj7UbPD!B@Ok*O8 zN@z|06oKzrO91K`N7waG>=n+T73YGtJw#&eUOFVNSLI!7tT`>>;iw@3f)$H7*KEhG zJgl@jEyCVK7K+;`FYX>9SaIE31Fy+5c)i6wDTjJZN|}vtx;)J9!t^-W)O`233I(t1 z*1k$RoO{RFs6g=?&uR@e&$F|$E1F@X0Ib0A$2UM@r{k*pOFtF?aOaO0?o`oR@tmZ; z`UoT!APAW)Q>Glh%`YoiAi^J2PW&1va@gKVQb2Bgek29{Tr%F(-hSJt`gxn}m*2ev z=LWz%M@%GEskOqcGvbdAx2g!#Hm@egX8m*~B6bntUJv>Q``o_yn?*G+)Pw+@|5O)2`3m^Vk{2WjOIW$y6Sa|Xe zCxe&bEqp*kVpipjW??uENbL!7bKs1PVn9a5naJSypYu~o+x+)~lFQKV$z F|9=lMs;U40 literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt new file mode 100755 index 0000000000..a4eaeea43f --- /dev/null +++ b/docs/examples/te_gemma/requirements.txt @@ -0,0 +1,4 @@ +transformers==4.55.0 +accelerate==1.10.0 +datasets==4.0.0 +sentencepiece==0.2.1 
diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100755 index 0000000000..6285fea1a9 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,703 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +from torch.amp import autocast + +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.pytorch.fp8 import get_default_fp8_recipe +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + +""" +Top level description of the classes used in the tutorial from this file. +---------------------------------------------------------------------- + +HuggingFace Gemma Model implementation hierarchy: +---------------------------------- +GemmaDecoderLayer: +├── self_attn: +│ ├── norm: (nn.LayerNorm) +│ ├── qkv_proj: (nn.Linear) +│ ├── attention: (SDPA, FlashAttention, etc.) +│ └── o_proj: (nn.Linear) +├── ffn: +│ ├── norm: (nn.LayerNorm) +│ ├── gate_proj: (nn.Linear) +│ ├── up_proj: (nn.Linear) +│ └── down_proj: (nn.Linear) + +GemmaModel: +├── embed_tokens : Token embedding layer +├── layers : GemmaDecoderLayer × N +├── norm : GemmaRMSNorm +└── rotary_emb : GemmaRotaryEmbedding + +GemmaForCausalLM: +├── model : instance of GemmaModel +├── lm_head : (nn.Linear) hidden states to vocabulary logits for generation +└── generate : generate method (input prompt -> GemmaForCausalLM -> next tokens) + +How `generate()` works in HF's GemmaForCausalLM: + 1. prefill (input prompt -> model -> lm_head -> logits -> next token) + 2. 
loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: Notice how "prefill" and "loop until next tokens" are just part of the `generate()` method. + This is a common pattern in HF models. + + +TransformerEngine's Gemma Model Hierarchy: +---------------------------------------- +HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way, +while the model is downloaded from HuggingFace and most of the code runs from HF's `GemmaForCausalLM`, the underlying +blocks of "transformer layer" are actually from TransformerEngine. + +TEGemmaDecoderLayer (inherits from te.TransformerLayer): +├── te.MultiHeadAttention: +│ ├── linear_qkv: (te.LayerNormLinear) +│ ├── attention: (te.DotProductAttention) +│ └── out_proj: (te.LayerNormLinear) +├── te.LayerNormMLP: +│ ├── fc1: (te.LayerNormLinear) +│ ├── fc2: (te.Linear) +│ └── activation: (te.GeGLU) + +To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which +subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods. 
+ +TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM) +├─ model : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N +├─ lm_head : directly inherited from HF's GemmaForCausalLM +├─ te_rope_emb : RotaryPositionEmbedding (reusing the same for all layers for CUDA graphs compatibility) +├─ hidden_states_buffer : shape [b, max_ctx, h] (static) +├─ generation_buffer : shape [b, 1, h] (view of `hidden_states_buffer`) (static) +├─ inference_params : TransformerEngine KV cache +├─ model_context_phase : GemmaModelWrapper → uses (model, lm_head, inference_params) for full-sequence prefill +├─ model_generation_phase : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode +└─ generate : generate method (input prompt -> TEGemmaForCausalLM -> next tokens) + +Notice how "prefill" and "loop until next tokens" are specialized to wrapper subroutines - "model_context_phase" and +"model_generation_phase" respectively which makes it easier to use CUDA Graphs. Just one more abstraction is needed: + +TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM) +├─ model : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N) +├─ lm_head : unchanged +├─ hidden_states_buffer : unchanged +├─ generation_buffer : unchanged +├─ inference_params : unchanged +├─ record : utility function to record the graphed callable +├─ model_context_phase : GraphedCallable(for Context/prefill) replaced by `record` +├─ model_generation_phase : GraphedCallable(for Generation) replaced by `record` +└─ generate : unchanged + +How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs: + 1. model_context_phase (input prompt -> model -> lm_head -> logits -> next token) + 2. model_generation_phase: + - loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: In the tutorial, `record` is called when initializing the model. 
+ +Additional notes and clarifications +----------------------------------- +- Wrappers, not submodules: + `model_context_phase` and `model_generation_phase` are convenience wrappers over the same + `model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage, + masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and + KV-cache (`InferenceParams`) flow for TE-optimized inference. + +- Buffer relationship: + `hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view + of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates + `generation_buffer` in-place with next-token embeddings. + +- Padding policy: + Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end + to match TE attention mask expectations and to keep shapes contiguous for capture/replay. + +- CUDA Graphs specifics: + `record()` captures two separate callables (context/prefill and generation) with fixed shapes and + stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the + functional behavior is identical; only allocation/pointer churn and CPU overhead are removed. +""" + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. 
+ + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + # filter out HF specific args + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + rope_emb = kwargs.pop("rope_emb", None) + + # Return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),) + + +class GemmaModelWrapper(torch.nn.Module): + """ + Encapsulates the HuggingFace GemmaModel class as a wrapper whose + forward pass is compatible with CUDA Graphs. 
+ """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + rope_emb=rope_emb, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + + # This is not needed for generation but is needed for training + # or finetuning. + if self.training: + logits = logits.float() + + return logits + + +class GemmaGenerationWrapper(torch.nn.Module): + """ + Gets token embeddings for a batch of single tokens, runs forward pass, and + returns the batch ofnext tokens. Also compatible with CUDA graphs. Not a + subclass of `GemmaModel` since the model layers are simply reused here. 
+ """ + + def __init__( + self, + model: GemmaModel, + lm_head: torch.nn.Module, + dtype: torch.dtype, + ): + super().__init__() + self.model = model + self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head) + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + logits = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + + # Fetch the logits for the last token + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Monkey-patches `GemmaDecoderLayer` with the custom `TEGemmaDecoderLayer` + class. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: Gemma model config that HF uses to initialize the model. 
+ """ + + def __init__(self, config: GemmaConfig): + + dtype = torch.bfloat16 + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + + self.config = config + self.to(dtype).cuda() + self.hidden_size = config.hidden_size + + self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head) + + self._model_generation_phase = GemmaGenerationWrapper( + lm_head=self.lm_head, + model=self.model, + dtype=dtype, + ) + + if self.config.fp8: + self.fp8_recipe = get_default_fp8_recipe() + + # Rotary position embedding remains the same for all the layers and so + # created here. This makes it compatible with CUDA Graphs too. + self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)( + max_seq_len=self.config.max_position_embeddings + ).cuda() + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + updates it inplace to be padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Trim the inputs to no extra padding i.e. fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + + def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor): + """ + Returns a tensor of shape [b, s, hd] where `b` is the batch size, + `s` is the sequence length, and `hd` is the hidden size. 
+
+        This function is overridden in TEGemmaForCausalLMCudaGraphs.
+        """
+
+        tensor = torch.empty(
+            (input_ids.shape[0], input_ids.shape[1], self.hidden_size),
+            device="cuda",
+            dtype=torch.float32,
+        )
+        return tensor
+
+    def _create_or_fetch_inference_params(self, *args, **kwargs):
+        """
+        Creates an InferenceParams object.
+
+        This function is overridden in TEGemmaForCausalLMCudaGraphs.
+        """
+
+        infer_params = InferenceParams(*args, **kwargs)
+        return infer_params
+
+    def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
+        """
+        Returns a tensor of shape [b, 1, hd] where `b` is the batch size,
+        `hd` is the hidden size.
+
+        The buffer for generation is some part (beginning) of hidden states buffer.
+        This function returns pointer to it and also copies data there if provided.
+        """
+        # hidden_states_buffer has shape [b, s, hd]
+        # generation_buffer will have shape [b, 1, hd]
+        # Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` will return
+        # non-contiguous buffer, which we want to avoid.
+        output = hidden_states_buffer.view(-1)[
+            : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
+        ]
+        if data_to_copy is not None:
+            output.copy_(data_to_copy.reshape(-1))
+        generation_buffer = output.view(
+            (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
+        )
+        return generation_buffer
+
+    def setup_and_run_context_phase(
+        self, input_ids: torch.Tensor, inference_params: InferenceParams
+    ):
+        """
+        Runs the context or prefill phase of the model.
+
+        This function is overridden in TEGemmaForCausalLMCudaGraphs.
+        """
+
+        hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids)
+        hidden_states.copy_(self.model.embed_tokens(input_ids))
+
+        # Update offsets before every forward pass (including context/prefill
+        # phase) to make cache work properly.
+ lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + + logits = self._model_context_phase( + hidden_states, + attention_mask=None, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + next_tokens = torch.argmax(logits, dim=1) + + # `self.hidden_states` has shape [b, s, hd]. + # Return hidden state for the last token - output has shape [b, 1, hd]. + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs, + ): + """ + Generates next tokens auto-regressively for a batch of input tokens. + """ + self.eval() + + # Both autocasts are needed: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() + # If padding is at the beginning, then shift it to the end + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + batch_size = input_ids.shape[0] + # For benchmark generation run, this is being set explicitly. + max_input_sequence_len = self.config.max_seq_length + + # InferenceParams is a cache, where keys and values of previous + # tokens are stored. Moreover it stores the current running lengths + # of the sequences in the current batch. 
+ # A helper function is used to create the inference params object + # because this `generate` method is common for TEGemmaForCausalLM + # and TEGemmaForCausalLMCudaGraphs. In case of CudaGraphs, this + # function is overriden to simply return the inference params object + # that is already created in TEGemmaForCausalLMCudaGraphs' + # constructor. + inference_params = self._create_or_fetch_inference_params( + max_batch_size=batch_size, + max_sequence_length=max_input_sequence_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=batch_size * max_input_sequence_len // 16, + ) + + # Set the inference params for both the context/prefill phase and + # generation phase objects. + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + # Context/prefill phase. + hidden_states, next_tokens = self.setup_and_run_context_phase( + input_ids, inference_params + ) + + # Generation phase. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + output_tokens = [next_tokens] + + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase( + hidden_states, + mask=None, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + # Increase sequence offsets by one because we generated one token + # for every sequence. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # `next_tokens` is a static output tensor, so we need to clone + # it because it gets changed every iteration. 
+ output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + """ + Forward pass for the model. This is used in calibration step when + forward pass is needed to generate FP8 calibration data. + """ + + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + kwargs["input_ids"] == 0 + ), # Hardcoded, this only applies to bshd/sbhd layouts. + attn_mask_type="padding_causal", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is a wrapper over the class TEGemmaForCausalLM + and uses CUDA Graphs to speed up the generation process. We need to make one + trade-off - batch_size, max_seq_len and max_context_seq_len need to + be static. It is necessary to run generation without changing the pointer + to the variables that are recorded in the graph. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + + self.config = config + + # Preparation of the static buffer to hold the hidden states that are + # passed from one layer to the next. + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + + # This is in fact part of the buffer for hidden_states. Refer to the + # `_get_generation_buffer` function for more details. + self.generation_buffer = self._get_generation_buffer( + self.hidden_states_buffer, + ) + + # InferenceParams contains the keys and values cache. Refer to the + # original call in TEGemmaForCausalLM's `generate` method for more + # details. 
+ self.inference_params = InferenceParams( + max_batch_size=self.config.cuda_graphs_static_batch_size, + max_sequence_length=self.config.cuda_graphs_static_max_context_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=self.config.cuda_graphs_static_batch_size + * self.config.cuda_graphs_static_max_context_len + // 16, + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + """ + Here "the trick" happens. `_model_context_phase` and + `_model_generation_phase` from TEGemmaForCausalLM are replaced with + their recorded version. Once the graphs are recorded, they can be + replayed with minimal usage of CPU and that leads to speedup. + """ + # Record the model with training=False, because it will be used in + # generation. + self.eval() + + # Setup the recording for context/prefill phase. + input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + + # Hardcoded value for the context length. + lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to( + device="cuda", dtype=torch.int32 + ) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for context/prefill phase. + self._model_context_phase = self.record_graph( + self._model_context_phase, + self.hidden_states_buffer, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + # Setup the recording for generation phase. 
+        input_shape = (self.config.cuda_graphs_static_batch_size, 1)
+        lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
+        self.inference_params.pre_step(
+            OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+        )
+
+        # Record the graph for generation phase.
+        self._model_generation_phase = self.record_graph(
+            self._model_generation_phase,
+            self.generation_buffer,
+            attn_mask_type="padding",
+            rope_emb=self.te_rope_emb,
+        )
+
+    def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs):
+        """
+        Overridden to make `hidden_states` static i.e. not change its pointer
+        in memory between every invocation.
+
+        Returns the static buffer for `hidden states` which is already created
+        in the constructor. This is the same buffer as used in the
+        context/prefill phase.
+        """
+        return self.hidden_states_buffer
+
+    def _create_or_fetch_inference_params(self, *args, **kwargs):
+        """
+        Overridden to make `inference_params` static i.e. not change its pointer
+        in memory between every invocation.
+
+        Returns the static buffer for `inference_params` which is already created
+        in the constructor.
+        """
+        self.inference_params.reset()
+        return self.inference_params
+
+    @torch.no_grad()
+    def record_graph(self, function, input_tensor, **sample_kwargs):
+        """
+        Records the graph for the given function. The function is invoked on
+        argument (input_tensor,) and all kernels are recorded.
+        It then returns the captured callable, which can be run later while
+        minimizing CPU usage.
+        """
+        fp8_recipe = get_default_fp8_recipe()
+
+        # We need both autocasts: FP8 for operations that can run in lower
+        # precision and BF16 for those that cannot.
+ with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=5, + sample_kwargs=sample_kwargs, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100755 index 0000000000..d0df9edc58 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + """ + Loads weights and FP8 metadata from a calibrated weights file. + + The weights are in BF16 precision, but the state dict also contains + fp8 metadata computed by the calibration procedure. + """ + + fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename) + + # A hack to remove the extra state from the fp8_metadata_sd + # that contains the extra state from the core_attention module. 
+ fp8_metadata_sd = { + k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k + } + vanilla_model.load_state_dict( + fp8_metadata_sd, + strict=False, + # Because some parameters have multiple pointers to the same weight + # vanilla_model._model_context_phase.model and + # vanilla_model._model_generation_phase.model we need to load the + # weights in a non-strict manner. + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + """ + Loads weights from the HuggingFace checkpoint. + """ + + archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy remaining parameters like embedding. + vanilla_model.load_state_dict(total_dict, strict=False) + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Loads the TE model with proper weights. + """ + + # Force the dtype to bfloat16 while loading the model. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + + # Loading model with FP8 only weights needs both the following context managers. + # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 2. 
torch.no_grad() during TE modules' initilization so that they respect + # the `fp8_model_init` context manager. + with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # Just create a model with random weights. + vanilla_model = cls(config).cuda() + + # Copy proper weights into the model. If loading weights with FP8 metadata, + # then the source weights are basically the same as the weights in the model. + # If not, then we need to load the weights from the HuggingFace checkpoint + # and do mapping of the weight names from HF to the TE model. + if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + # Restore the original dtype. + torch.set_default_dtype(old_dtype) + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. 
+ """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. 
+ """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 0000000000..cc8675cfd8 --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,941 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "87e8360b-8d08-44bc-9333-79ba949afe8c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Accelerating Hugging Face Gemma Inference with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "id": "2da33092-eef5-46a4-b222-0188cc6e5079", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Introduction\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. 
These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "

\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", + "\n", + "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n", + "\n", + "### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n", + "\n", + "The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n", + "\n", + "\n", + "Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n", + "\n", + "1. Internal fragmentation\n", + "2. External Fragmentation\n", + "\n", + "More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. 
The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n", + "\n", + "\n", + "Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n", + "\n", + "\n", + "### 2. CUDA Graphs API\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "### 3. FP8 Scaling Factors Calibration\n", + "\n", + "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", + "\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2:\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "
\n", + "
\n", + "\n", + "### 4. FP8 Model Weights\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "
\n", + "
\n", + "\n", + "### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "The following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images and other artefacts used in this tutorial." 
+ ] + }, + { + "cell_type": "markdown", + "id": "36767694-a1c5-4a00-a075-7addc55d8307", + "metadata": {}, + "source": [ + "### Setup and checks" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c756ebbd-24c9-4a54-a381-e7c02c555206", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. 
The\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n", + "* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 46.60 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "\n", + "model = init_baseline_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time | Speedup | \n", + 
"|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 46.6 s | - |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Optimization 1] Accelerating generation with Transformer Engine " + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 12.25 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!" 
+ ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |" + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Optimization 2] More acceleration with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```python\n", + " def __init__(self, config : GemmaConfig):\n", + " \"\"\"\n", + " Here \"the trick\" happens. `_model_context_phase` and\n", + " `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n", + " their recorded version. Once the graphs are recorded, they can be\n", + " replayed with minimal usage of CPU and that leads to speedup.\n", + " \"\"\"\n", + " (...)\n", + " # Record the graph for context/prefill phase.\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n", + "\n", + " (...) \n", + " # Record the graph for generation phase.\n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer)\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " \"\"\"\n", + " Records the graph for the given function. 
The function is invoked on\n", + " argument (self.hidden_states,) and all kernels are recorded.\n", + " It then returns the captured callable, which can be run later while\n", + " minimizing CPU usage.\n", + " \"\"\"\n", + " fp8_recipe = get_default_fp8_recipe()\n", + "\n", + " # We need both autocasts: FP8 for operations that can run in lower\n", + " # precision and BF16 for those that cannot.\n", + " with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function,\n", + " (input_tensor,),\n", + " fp8_enabled=self.config.fp8,\n", + " fp8_recipe=fp8_recipe,\n", + " allow_unused_input=True,\n", + " num_warmup_iters=5,\n", + " sample_kwargs=sample_kwargs,\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n", + "\n", + "*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 6.39 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.max_seq_length = 512\n", + "run_config.batch_size = 64\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + 
"run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n", + "\n", + "1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n", + "2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n", + "3. So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n", + "4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n", + "\n", + "(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n", + "\n", + "
\n", + "\n", + "
\n", + " \n", + "Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Optimization 3] Even more acceleration with FP8 precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "### Calibrating FP8 scaling factors for correctness\n", + "\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "\n", + "1. Model weight tensors\n", + "2. Input tensors\n", + "\n", + "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", + "\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "\n", + "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", + " \n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n", + "
\n", + "
\n", + "\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "import transformer_engine.pytorch as te\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "run_config.fuse_qkv_params = True\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "# Calibration\n", + "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " model.train()\n", + " run_forward_pass(model, run_config, num_iters=64)\n", + "\n", + "# Compute scale_fwd with enabled fp8 autocast\n", + "with te.fp8_autocast(enabled=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " run_forward_pass(model, run_config, 1)\n", + "\n", + "# Some parameters are in pointing to the same tensors, double save is avoided here.\n", + "dict_to_save = {\n", + " k: v\n", + " for k, v in model.state_dict().items()\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n", + "}\n", + "torch.save(\n", + " dict_to_save, \"calibrated_weights.pth\"\n", + ") # <-- Add path to save calibrated 
weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "### Generation with better FP8 scaling factors\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n", + "
\n", + "
\n", + "\n", + "Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 8.73 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", 
+ "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8\n", + "run_config.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "### Use of FP8-only model weights\n", + "\n", + "Running the model in FP8 precision does not imply that the weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors before GEMM operations (matrix multiplications).\n", + "\n", + "This approach is appropriate during training since gradients during the backward pass are produced in higher precision, and therefore, having higher precision copies of model weights helps, as they have enough dynamic range to encompass incoming information from the gradients. 
During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs occur in FP8 precision, which helps save training time overall if the time saved from running GEMM in FP8 precision (than in higher precision) is more than the extra time spent during the cast operation.\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d", + "metadata": {}, + "source": [ + "However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n", + "\n", + "1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n", + "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", + "\n", + "\n", + "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. 
Let's see a small example:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4562ee82-8c95-4736-8815-cd386078a485", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory required for 16384x16384 linear layer: \n", + "FP32 - 1024.0 MB, \n", + "BF16 - 512.0 MB, \n", + "FP8 - 256.0 MB, \n", + "\n", + "Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n", + "Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n", + "Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n" + ] + } + ], + "source": [ + "import torch\n", + "import transformer_engine.pytorch as te\n", + "\n", + "H = 2**14\n", + "D = 2**14\n", + "print(f\"Memory required for {H}x{D} linear layer: \\n\"\n", + " f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n", + " f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n", + " f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n", + "\n", + "linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n", + "print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp32\n", + "\n", + "linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n", + "print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_bf16\n", + "\n", + "# Initialize model weights in FP8 precision\n", + "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + " linear_fp8 = te.Linear(H, D)\n", + "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp8" + ] + }, + { + "cell_type": "markdown", + "id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 4.99 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", 
+ "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8 math and FP8 model weights\n", + "run_config.fp8 = True\n", + "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "The final speedup is **9.3x**. 
\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n", + "\n", + "1. Support for KV Caching (both non-paged and paged),\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 scaling factors calibration,\n", + "4. Keeping model parameters in FP8 precision.\n", + "\n", + "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", + "\n", + "1. Longer context lengths (with paged KV cache) \n", + "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "\n", + "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 0000000000..cc31afc65a --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,370 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model +import torch +from torch.utils.data import DataLoader + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + +random.seed(42) +torch.manual_seed(42) + + +class RunConfiguration: + def __init__(self): + self.mixed_precision = "bf16" + self.model_name = None + + # FP8 precision settings + self.fp8 = False + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 64 + self.cuda_graphs_static_max_seq_len = 512 + self.cuda_graphs_static_max_context_len = 512 + + # Finetuning/calibration/generation settings + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 64 + self.max_seq_length = 512 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps = 5 + self.num_training_steps = 10 + + # Coalesced QKV params or 
not + self.fuse_qkv_params = False + + # Attention + self.is_paged = False + + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + + +# Global variable for the run configuration so that it can be easily accessed +# throughout the jupyter notebook with an `import * from utils` statement +run_config = RunConfiguration() + + +def get_dataloaders(run_config): + """ + Returns a basic dataloader for the dataset which contains tokenized batches + of text. + """ + dataset = load_dataset(run_config.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=run_config.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + # Tokenize the dataset + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": run_config.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def ensure_model_is_downloaded(run_config): + """ + Downloads and caches the model weights if not already downloaded. A valid + Huggingface Access Token is required to download the model weights. + """ + assert run_config.model_name in [ + "google/gemma-7b", + ], "Only Gemma 7B model is supported!" 
+ + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(run_config.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." + ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None + ) + run_config.weights_cache_dir = snapshot_download( + repo_id=run_config.model_name, cache_dir=supplied_cache_dir + ) + + +def init_baseline_model(run_config): + """ + Initializes a baseline HF Gemma model with the model name provided in + the run_config. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + # Init the model + config = AutoConfig.from_pretrained(run_config.model_name) + + # Make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + run_config.model_name, + config=config, + torch_dtype=torch.bfloat16, + ).cuda() + + return model + + +def init_te_gemma_model(run_config): + """ + Initializes a Gemma model with `GemmaDecoderLayer`s swapped with + `TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled, + the model is initialized from `TEGemmaForCausalLMCudaGraphs` class. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(run_config.model_name) + + # Inject all fields from the `run_config` to the model `config` to make the + # code simpler. 
+ for key, value in run_config.__dict__.items(): + setattr(config, key, value) + + # Initialize the model and move it to the GPU. + model = load_te_model(cls, config).cuda() + + # Record the model if CUDA Graphs are enabled. + if run_config.generation_cuda_graphs: + model.record() + + return model + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, run_config, num_iters): + """ + Runs the forward pass of the model with sample data. Intended to use for + warmup and/or calibration. + """ + train_dataloader = get_dataloaders(run_config) + + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + batch["attention_mask"] = batch["attention_mask"].cuda() + model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + + +############################################################################### +# Benchmarking and example generation functions. +############################################################################### + + +def print_sample_of_generated_texts(model, run_config): + """ + Prints a sample of generated texts from the input model. 
+ """ + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + prompts = [ + "Here are the two facts about GPUs:", + "Some facts about NVIDIA:", + "The fundamental theorem of calculus for the layman:", + "A fact about AI:", + ] + + # Repeat prompts to match batch size + prompts *= run_config.batch_size // len(prompts) + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * max_total_tokens + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def print_output(prompts, generated_texts, idx): + print("=" * 30 + f" Generation example {idx+1} " + "=" * 30) + print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"') + print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"') + + # Print the output from first two prompts + for i in range(2): + print_output(prompts, generated_texts, i) + + +def _generate_random_words(num_words, max_word_length): + """ + Generates random words for the benchmark. 
+ """ + + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model, run_config, context_length=20): + """ + Benchmarks the generation time for a random input to the model. + """ + + batch_size = run_config.batch_size + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + max_new_tokens = max_total_tokens - context_length + + print("\n" + "=" * 80) + print( + f"Benchmarking for batch_size = {batch_size}, prefill tokens =" + f" {context_length} and max new tokens = {max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + max_context_tokens = inputs["input_ids"].size(1) + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (max_total_tokens - max_context_tokens, 0), + value=tokenizer.pad_token_id, + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb 
b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 7013e85ec6..00499cff5f 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -5,7 +5,7 @@ "id": "6a5b2993", "metadata": {}, "source": [ - "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", + "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", "\n", "
\n", "\n", diff --git a/docs/index.rst b/docs/index.rst index e678b1d467..2c04810f4d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,7 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb .. toctree:: diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8d5417a45c..f0ef8d0bd5 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -215,6 +215,17 @@ def __init__( device=torch.cuda.current_device(), ) + # This internal buffer holds the running length of each + # unfinished sequence in the batch and is updated in `pre_step()` + # method. One use of this buffer is applying RoPE to q and k tensors + # during inference by slicing ROPE Embeddings according to the + # current sequence length window. + self.pre_step_seqlens = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() @@ -266,6 +277,15 @@ def pre_step( for k, v in self.sequences.items(): self.sequences_pre_step[k] = v - step_dict[k] + pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + # Copy the pre-step seqlens to the device in CUDA Graphs safe manner. 
+ self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_( + pre_step_seqlens_temp, non_blocking=False + ) + seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) @@ -280,9 +300,7 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) + return self.pre_step_seqlens def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -458,14 +476,14 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( + self.batch_indices.data[:].copy_( torch.Tensor( ( unfinished_indices + finished_indices + list(range(prev_batch_size, self.max_batch_size)) ) - ).to(dtype=torch.int32, device="cpu") + ) ) # Advance unfinished sequences diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 9c82442af6..5fd16bf1a1 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -889,23 +889,11 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) - else: - raise ValueError( - f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." 
- ) - - sequence_start = inference_params.get_seqlens_pre_step() - # sequence_start = inference_params.seqlens[0] - sequence_end = sequence_start + sequence_length - - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + # Applyig RoPE for inference needs start positions of sequences + # for each iteration. + sequence_start_positions = ( + inference_params.get_seqlens_pre_step() if inference_params is not None else None + ) if pad_between_seqs: rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded @@ -922,6 +910,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) key_layer = apply_rotary_pos_emb( @@ -932,6 +921,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index d1ba1a351c..064da8a670 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor if (start_positions) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } if (qkv_format == NVTE_QKV_Format::NVTE_THD) { diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 
e9189ccc59..5749d96c9f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -883,7 +883,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f31132..ee24dc33f0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1767,7 +1767,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a6c55ceb79..9f799c5538 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -445,14 +445,19 @@ def forward( act_out = activation_func(fc1_out, None) act_out = tex.quantize(act_out, fc2_input_quantizer) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8_calibration: + act_out = activation_func(fc1_out, None) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) - if fp8_calibration: - fc2_input_quantizer.calibrate(act_out) - fc2_weight_quantizer.calibrate(fc2_weight) + if not fp8 and fp8_calibration: + if fc2_input_quantizer is not None: 
+ fc2_input_quantizer.calibrate(act_out) + if fc2_weight_quantizer is not None: + fc2_weight_quantizer.calibrate(fc2_weight) # Configure Userbuffers reduce-scatter if needed ub_obj_fc2out = None @@ -1897,7 +1902,7 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() - if self.fp8: + if self.fp8 or self.fp8_calibration: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] @@ -2114,7 +2119,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None, None] fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1d..3bc8074131 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1643,7 +1643,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True From 93a67af81a98f6542ecb2e414360bd0a74ca4367 Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Wed, 17 Sep 2025 13:59:52 +0800 Subject: [PATCH 23/78] Fix memory overhead of linear layer when all gather from sequence parallel (#2125) * fix memory overhead of all gather from sequence parallel Signed-off-by: 
Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * quick fix the errors when for UB buffers Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --------- Signed-off-by: Yuzhong Wang Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon --- .../pytorch/module/layernorm_linear.py | 23 +++++++++++++++---- transformer_engine/pytorch/module/linear.py | 16 ++++++++++++- .../_internal/float8_blockwise_tensor_base.py | 5 ++++ .../tensor/_internal/float8_tensor_base.py | 9 ++++++-- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ee24dc33f0..4d30be414e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -353,8 +353,11 @@ def forward( # Deallocate GEMM input tensor if no longer needed if not weight.requires_grad and not return_layernorm_output: - ln_out = ln_out_total = None clear_tensor_data(ln_out, ln_out_total) + ln_out = ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ 
-891,9 +894,19 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted - if not ctx.return_layernorm_output: + # Deallocate input tensors if permitted + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + # Input tensors have not been exposed externally + clear_tensor_data(ln_out) + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: + # Gathered input is internal clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1169,7 +1182,9 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3bc8074131..7e526245c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -317,6 +317,13 @@ def forward( # Finished forward GEMM... # ------------------------------------------------------ + # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
+ if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + # ------------------------------------------------------ # Prepare output tensor # Note: Perform tensor-parallel communication @@ -878,9 +885,16 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate tensors if permitted if ctx.owns_input: + # Input tensor is internal + clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal clear_tensor_data(inputmat_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index adffe7c580..da0220eb7a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -349,9 +349,14 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
+ _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None ) + _old_data.data = _empty_tensor() + del _old_data def __repr__(self): if self._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 61edc999ac..6d48223443 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -95,8 +95,13 @@ def __new__( return instance def clear(self): - """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" - for t in (self._data, self._transpose, self._scale_inv): + """Deallocate this tensor's memory. Typically not needed and must be used carefully. + + Scale-inv tensor is not deallocated because it's often shared + between multiple FP8 tensors. + + """ + for t in (self._data, self._transpose): if t is not None: t.data = _empty_tensor() self._transpose_invalid = True From eb69fad7b865f034015ad149b898a031fbd810d3 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:15:20 +1200 Subject: [PATCH 24/78] Fix incorrect TP rank calculation when using data parallel (#2179) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++++---- .../comm_gemm_overlap/userbuffers/userbuffers.cu | 14 ++++++++------ .../comm_gemm_overlap/userbuffers/userbuffers.h | 4 ++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 0874934958..ec29e6e120 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -607,10 
+607,10 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr int comm_bytes_per_rank = comm_bytes / _tp_size; // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush - userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - send_stream); - userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - recv_stream); + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf for (auto stream : {send_stream, recv_stream}) { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 17f3cf658e..1dcd54d0d7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2542,25 +2542,27 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = 1; j < tp_size; j++) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * tp_rank; int recv_offset = dstoffset + bytes_per_slice * tp_rank; - userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, 
comm, i, - stream); + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = tp_size - 1; j > 0; j--) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * i; int recv_offset = dstoffset + bytes_per_slice * i; - userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, - stream); + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 8077f90be8..4d52fbb644 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -306,10 +306,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ From 
8aee1bb774998556e8fcc1234e7bb137bd5d0c43 Mon Sep 17 00:00:00 2001 From: alan yang <89962857+cassiewilliam@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:23:15 +0800 Subject: [PATCH 25/78] [Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045) * feat: add cutlass group gemm support Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor multi tensor gemm interface Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor nvte_multi_stream_cublas_gemm func and add license info Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add unit test for cutlass group gemm Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add cutlass support type protect Signed-off-by: Min Yang * add tests and fix lint Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: fix unit tests error Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: refactor host workspace malloc Signed-off-by: Min Yang * update cutlass Signed-off-by: Xin Yao * update cutlass Signed-off-by: Xin Yao * further relex threshold and add a env var to warn fall back Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Min Yang Signed-off-by: Xin Yao Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com> Co-authored-by: Min Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen --- 
.gitmodules | 3 + 3rdparty/cutlass | 1 + tests/pytorch/test_numerics.py | 68 +++- transformer_engine/common/CMakeLists.txt | 22 +- .../common/gemm/cublaslt_gemm.cu | 119 +++++- .../common/gemm/cutlass_grouped_gemm.cu | 77 ++++ .../common/gemm/cutlass_grouped_gemm.cuh | 348 ++++++++++++++++++ .../common/include/transformer_engine/gemm.h | 11 +- .../jax/csrc/extensions/gemm.cpp | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 9 +- 10 files changed, 633 insertions(+), 33 deletions(-) create mode 160000 3rdparty/cutlass create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cuh diff --git a/.gitmodules b/.gitmodules index 21492db5ef..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 0000000000..57e3cfb47a --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a50b3fbca5..a0e285b913 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -125,6 +125,11 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +use_cutlass_grouped_gemm = [False] +# Only enable cutlass grouped gemm on Hopper +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + def is_fused_attn_available( config: ModelConfig, @@ -1805,6 +1810,7 @@ def test_grouped_linear_accuracy( bias, delay_wgrad_compute, parallel_mode=None, + use_cutlass=False, ): fp8 = recipe is not None if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: @@ -1876,9 +1882,47 @@ def 
test_grouped_linear_accuracy( delay_wgrad_compute, ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, +): + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2542,10 +2586,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): (16, 10027, 128, 512), ], ) -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_grouped_gemm(shape, dtype, layout, accumulate): +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape @@ -2580,6 +2625,9 @@ def test_grouped_gemm(shape, dtype, layout, 
accumulate): grad = True single_output = False + if use_cutlass: + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + for i in range(z): general_gemm( A[i], @@ -2607,9 +2655,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): single_output=single_output, ) - # should be bit-wise match for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + if use_cutlass: + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("N", [32]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cb9f13b899..08e876404c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -81,6 +86,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cutlass_grouped_gemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -121,18 +127,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_grouped_gemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode 
arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") +endif() # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) + target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9e6c5417bc..f287072bcb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -19,6 +19,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "cutlass_grouped_gemm.cuh" namespace { @@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor CUBLAS_VERSION); #endif NVTE_CHECK( - cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", - cuda::cudart_version()); + transformer_engine::cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", @@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const 
NVTETensor B, NVTETensor n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, - const NVTETensor *bias, NVTETensor *pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor *workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_stream_cublas_gemm); +void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { using namespace transformer_engine; int num_streams = nvte_get_num_compute_streams(); @@ -711,6 +711,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } } +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + // Deprecation warning + NVTE_WARN( + "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. 
" + "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when " + "applicable)."); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + namespace transformer_engine { using cublasHandleManager = detail::HandleManager; @@ -718,3 +737,85 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + // Currently only support cutlass group gemm on Hopper Arch + if (!(is_hopper && use_cutlass)) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + + auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { + int64_t ref_k = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]); + const int k = trans ? 
tensor->data.shape[0] : tensor->data.shape[1]; + + if ((k & 127) != 0) return false; + + if (ref_k < 0) + ref_k = k; + else if (k != ref_k) + return false; + } + + return true; + }; + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputD->data.dtype); + + return (A_type == B_type) && (A_type == D_type) && + ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); + }; + + // CUTLASS Grouped GEMM fast path (SM90/TMA) + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - Uniform K across groups and K % 128 == 0. + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && + all_groups_uniform_k128(B, transb)) { + cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); + } else { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu new file mode 100644 index 0000000000..18736c4f54 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ **************************************************************************************************/ + +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass_grouped_gemm.cuh" + +namespace transformer_engine { +namespace grouped_gemm { + +// Explicit template instantiation to match the template declarations in the .cuh +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); + +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream) { + using namespace transformer_engine; + auto* inputA = convertNVTETensorCheck(A[0]); + auto* inputB = convertNVTETensorCheck(B[0]); + + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? 
one : zero; + + auto dispatch = [&](auto tag) { + using T = decltype(tag); + if (!transa && !transb) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (!transb && transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (transb && !transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else { + NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm."); + } + }; + + if (inputA->data.dtype == DType::kBFloat16) { + dispatch(cutlass::bfloat16_t{}); + } else if (inputA->data.dtype == DType::kFloat16) { + dispatch(cutlass::half_t{}); + } else { + NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh new file mode 100644 index 0000000000..1add571325 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// +// Copyright (c) 2025 Shopee Inc. All Rights Reserved. +// + +/** + * @file: cutlass_grouped_gemm.cuh + * @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com + * @date: 2025-08-08 16:20:00 + * @brief: cutlass group gemm kernel. 
+ **/ + +#pragma once + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +template +using GroupedGemmInputALayout = + std::conditional_t; + +template +using GroupedGemmInputBLayout = + std::conditional_t; + +using ProblemShapeType = cute::Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape; // per group +template +struct GemmGivenSchedule { + using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand + using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand + using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands + + // A matrix configuration + using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC 
= typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = + typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +struct ScheduleConfig { + using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + // TODO(Alan): Add tuning for different scenarios to select the optimal configuration, + // as the current configuration may not be the best. + + // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + // using TileShape = Shape; + // using ClusterShape = Shape; + + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + using DataType = DataType_; +}; + +template +using GemmGrouped = typename GemmGivenSchedule>::Gemm; + +template +typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host, + void* problem_sizes, const ElementA** ptr_A, + StrideA* stride_A, const ElementB** ptr_B, + StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C, + float alpha, float beta, int device, int math_sm_count) { + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
+ + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info( + device, math_sm_count); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = + typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, reinterpret_cast(problem_sizes), + reinterpret_cast(problem_sizes_host)}, + {ptr_A, stride_A, ptr_B, stride_B}, + { + fusion_args, + (beta > 0.0) ? (const ElementC**)ptr_C : nullptr, // NOLINT(*) + stride_C, + ptr_C, + stride_C, + }, + kernel_hw_info}; + + return arguments; +} + +template +inline __device__ __host__ T ROUND_UP(T m, T n) { + return (m + n - 1) / n * n; +} + +template +void debug_type() { + std::cout << typeid(T).name() << std::endl; +} + +int64_t inline getGemmCoordSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL)); +} + +int64_t inline getPtrSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL)); +} + +int64_t inline getLddSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL)); +} + +// cpu workspace size is 4MB +static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024; + +static char* getHostWorkspace() { + static std::once_flag flag; + static std::shared_ptr workspace; + + std::call_once(flag, [&]() { + workspace = + std::shared_ptr(reinterpret_cast(std::malloc(kCPUWorkSpaceSize)), [](char* p) { + if (p) std::free(p); + }); + + if (!workspace) { + throw std::bad_alloc(); + } + }); + + return workspace.get(); +} + +template +void CutlassGroupedGemm(const 
NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + auto gemm_coord_size = getGemmCoordSize(num_gemms); + auto ptr_size = getPtrSize(num_gemms); + auto ldd_size = getLddSize(num_gemms); + auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK( + param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize size: required=", static_cast(param_workspace_size), + ", available=", static_cast(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM."); + + auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel()), " for CUTLASS grouped GEMM."); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + + char* kernel_workspace_ptr = nullptr; + + char* host_workspace = getHostWorkspace(); + + ProblemShapeType* problem_sizes_host = reinterpret_cast(host_workspace); + + ElementA** ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + ElementB** ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + ElementC** 
ptr_C_host = + reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + int64_t* lda_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + int64_t* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + int64_t* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + for (size_t i = 0; i < num_gemms; i++) { + const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0]; + const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1]; + const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1]; + + auto problem = ProblemShapeType(m, n, k); + problem_sizes_host[i] = problem; + + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + char* param_workspace_ptr = workspace_ptr; + ProblemShapeType* problem_sizes_device = reinterpret_cast(param_workspace_ptr); + const ElementA** ptr_A = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size); + const ElementB** ptr_B = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size); + ElementC** ptr_C = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 2 * ptr_size); + + StrideA* lda = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + 
+ gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + StrideB* ldb = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + StrideC* ldc = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + + // Check can implement the kernel. + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + } + + // Initialize the kernel. + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. + if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 50b33909fb..0c358328b6 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -133,12 +133,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. 
*/ -void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor* workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream); +void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 113072131d..06dded1d86 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -526,10 +526,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, - lhs_is_trans, grad, workspace_list.data(), accumulate, - use_split_accumulator, num_math_sm, stream); + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 485d67055e..0d18a5ec5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -477,11 +477,10 @@ std::optional> te_general_grouped_gemm( // For now, 
we only have multi-stream cublas backend. NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), - te_A_vector.size(), transa, transb, grad, - te_workspace_vector.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), + transa, transb, grad, te_workspace_vector.data(), accumulate, + use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); }); return bias; } From c334fc46bb166187b5b1da90b18ee16ddc4b2462 Mon Sep 17 00:00:00 2001 From: zhujian Date: Thu, 18 Sep 2025 11:14:08 +0800 Subject: [PATCH 26/78] [PyTorch] Support FA3 for MLA and with CP (#1907) feature(FA3,MLA,CP): 1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward 2. Update get_attention_backend method because FA3 support MLA now 3. Add CP MLA support for FA3 4. Add unit tests for FA3 MLA CP 5. 
Update attention doc Signed-off-by: zhujian --- docs/examples/attention/attention.ipynb | 2 +- .../attention/test_attention_with_cp.py | 8 + .../dot_product_attention/context_parallel.py | 249 +++++++++++------- .../attention/dot_product_attention/utils.py | 52 +++- 4 files changed, 205 insertions(+), 106 deletions(-) diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 6cd56d23da..61a6ad949f 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -390,7 +390,7 @@ "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n", diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0e8501abf3..7078cb69de 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -36,6 +36,12 @@ 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA + "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA + "cp_3_2": ModelConfig( + 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 + ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128), # MLA } @@ -81,6 +87,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
) + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f00bd573f1..09384217c6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -358,7 +358,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] return [ *[None] @@ -366,7 +366,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] if qkv_format == "thd": return [ @@ -829,6 +829,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -838,19 +851,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=True, **fa_forward_kwargs, @@ -985,6 +989,22 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1001,19 +1021,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1144,6 +1155,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1160,19 +1184,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1269,6 +1284,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1278,19 +1306,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1865,7 +1884,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -1875,16 +1914,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -1895,12 +1926,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = 
rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2016,7 +2046,29 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2026,16 +2078,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv // 2, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2046,12 +2090,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2160,7 +2203,27 @@ def backward(ctx, dout): dv_ = dv_._data 
else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2170,16 +2233,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q // 2, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2190,12 +2245,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse_, *fa_backward_args_thd, @@ -2267,7 +2321,15 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q) - dkv_ = torch.empty_like(kv) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + dkv_ = torch.empty_like(kv) + dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] + dv_ = dkv_[..., 1, :, :] if ctx.qkv_format 
in ["bshd", "sbhd"] else dkv_[1] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2277,8 +2339,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) @@ -2287,12 +2349,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + k_part, + v_part, out, softmax_lse, *fa_backward_args_thd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7097f4ba0f..fffda81365 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -126,10 +126,10 @@ class FlashAttentionUtils: # Please follow these instructions to install FA3 v3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git -(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install +(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 
https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" +(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False @staticmethod @@ -477,11 +477,10 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": logger.debug( @@ -508,10 +507,41 @@ def get_attention_backend( ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention_2 = False - if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): - if FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 for head_dim > 128") - use_flash_attention_3 = False + if use_flash_attention_3: + + def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if head_dim_qk > 256 or num_heads % num_gqa_groups != 0: + return False + if head_dim_qk != head_dim_v: + cond1 = 128 < head_dim_qk <= 192 + cond2 = 96 < head_dim_v <= 128 + cond3 = head_dim_qk <= 64 and head_dim_v <= 512 + if not ((cond1 and cond2) or cond3): + return False + if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16): + return False + return True + + if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if FlashAttentionUtils.v3_is_installed: + logger.debug( + "Disabling FlashAttention 3 due to unsupported num_heads, 
num_gqa_groups, " + "head_dim_qk, head_dim_v or qkv_dtype. " + "Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and " + "if head_dim_qk is different from head_dim_v, then " + "(head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or " + "(head_dim_qk <= 64 and head_dim_v <= 512), and " + "if head_dim_qk is different from head_dim_v and head_dim_v > 256, then " + "qkv_dtype requires fp16 and bf16 data type. " + "Found: num_heads = %s, num_gqa_groups = %s, " + "head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.", + num_heads, + num_gqa_groups, + head_dim_qk, + head_dim_v, + qkv_dtype, + ) + use_flash_attention_3 = False # Filter: QKV layout if qkv_format == "thd": From 7f77127cbe5dfc37d5ce02c2e7ba388cfa2a83d4 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:22:11 -0700 Subject: [PATCH 27/78] Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185) * Fix cudnn version checks for kv cache for sm89. 
Add cudnn version check in preparation for 9.14 when getting backend Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix for cuDNN version condition check Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 10 +++++----- .../pytorch/attention/dot_product_attention/utils.py | 6 ++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 60b10862e6..795697635d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -251,11 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 || - cudnn_runtime_version == 91300) && - is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && - !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && + // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed + (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && + head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && + head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == 
NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index fffda81365..9b2b9a1ac3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -434,8 +434,10 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") + # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version + # until the cuDNN bug is resolved + if device_compute_capability == (8, 9): + logger.debug("Disabling FusedAttention for KV caching for sm89") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") From 5b3092a0e40654436bec5ea0a0b0f7ad2887b20d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 19 Sep 2025 10:17:36 -0700 Subject: [PATCH 28/78] Changed VERSION to 2.9.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 81006d78c6..8bfb1cae85 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.8.0.dev0 +2.9.0.dev0 From 57b4d7bc0350917cd2122b07f144bfeb4c04eb0b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 22 Sep 2025 12:58:24 -0400 Subject: [PATCH 29/78] [JAX] Remove import jax.extend.ffi (#2193) * remove import jax.extend.ffi Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/activation.py | 7 +------ 
transformer_engine/jax/cpp_extensions/attention.py | 9 +-------- transformer_engine/jax/cpp_extensions/base.py | 8 +------- transformer_engine/jax/cpp_extensions/normalization.py | 8 +------- transformer_engine/jax/cpp_extensions/quantization.py | 8 +------- transformer_engine/jax/cpp_extensions/softmax.py | 8 +------- 6 files changed, 6 insertions(+), 42 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index cdda201668..d0a4e58fb6 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,11 +5,10 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec @@ -37,10 +36,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index df89174b2c..625f42049f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -8,11 +8,10 @@ from dataclasses import dataclass, replace from functools import partial, reduce from typing import Optional, Tuple -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes, lax +from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding from jax.experimental.custom_partitioning import SdyShardingRule @@ -49,12 +48,6 @@ ) -if version.parse(jax.__version__) >= 
version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - - __all__ = [ "FusedAttnHelper", "fused_attn_fwd", diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index c055705665..cc8a07860a 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,22 +7,16 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial -from packaging import version from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch +from jax import ffi -import jax import transformer_engine_jax -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - class BasePrimitive(metaclass=ABCMeta): """ diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 7a978c1b74..351767e367 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -7,11 +7,10 @@ import operator from functools import partial, cache, reduce from typing import Optional, Union -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -38,11 +37,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ 
"layernorm_fwd", diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1813734b5e..895913d0ac 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,11 +6,10 @@ from functools import reduce from typing import Tuple, Optional, Union import math -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec @@ -41,11 +40,6 @@ NoScaleTensor, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 43cb11a088..575a2dd3ab 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -6,22 +6,16 @@ from functools import partial, reduce import operator import warnings -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.sharding import PartitionSpec, NamedSharding from .base import BasePrimitive, register_primitive from .misc import get_padded_spec, check_valid_batch_dims from ..softmax import SoftmaxType -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "scaled_softmax_fwd", From 5e4e0b2c378d2b1ec2ee65dfa85124e1dd805389 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:53:25 -0700 
Subject: [PATCH 30/78] [PyTorch] Add sink attention support from cuDNN (#2148) * first draft; debug plan failure Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug uid error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak params Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add grad in output Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix prints in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * address review comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unfused grad; add softmax_type; add sink to bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding mask; add swa tests; remove requires_grad for off-by-one Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix indent Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix non-determinism and shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add GQA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add CP A2A; dq/dk mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; need cleaner solution Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; pending cudnn kernel change Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix world size in unit test; avoid thd format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 context Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak CP logging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * allow no_mask/padding for SWA(left,0) Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * Revert "allow no_mask/padding for SWA(left,0)" This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add softmax_type to Jax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add cuDNN version control Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prettify tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip 9.13 for MLA, non 192/128 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename compare_with_error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * small cleanups and improvements Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix minor CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force sink/dsink to be float32 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch FE to GH FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return to GH TE main FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE to 1.14.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up before CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * bump up cudnn version Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add backend selection guard for unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add docstring for softmax type enums in C Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../attention/run_attention_with_cp.py | 398 +++++++++------ tests/pytorch/attention/test_attention.py | 273 ++++++---- .../attention/test_attention_with_cp.py | 46 +- tests/pytorch/attention/test_kv_cache.py | 1 - tests/pytorch/utils.py | 42 +- .../common/fused_attn/fused_attn.cpp | 216 ++++---- .../fused_attn_f16_arbitrary_seqlen.cu | 467 ++++++++++-------- .../fused_attn_f16_arbitrary_seqlen.h | 61 +-- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.h | 12 +- .../include/transformer_engine/fused_attn.h | 110 +++-- .../common/util/pybind_helper.h | 4 + .../jax/csrc/extensions/attention.cpp | 167 ++++--- .../dot_product_attention/backends.py | 82 +-- .../dot_product_attention/context_parallel.py | 130 ++++- .../dot_product_attention.py | 44 +- .../attention/dot_product_attention/utils.py | 55 +++ .../pytorch/attention/multi_head_attention.py | 14 + .../pytorch/cpp_extensions/fused_attn.py | 23 + transformer_engine/pytorch/csrc/extensions.h | 25 +- .../pytorch/csrc/extensions/attention.cpp | 139 +++--- transformer_engine/pytorch/module/base.py | 15 +- transformer_engine/pytorch/transformer.py | 14 + 24 files changed, 1515 insertions(+), 827 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index deda80e537..1a7b4b78db 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 +Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0ad64204f7..7e47e7df8d 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ 
b/tests/pytorch/attention/run_attention_with_cp.py @@ -17,88 +17,18 @@ from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling +from utils import ModelConfig, compare_and_assert + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} -def run_dpa_with_cp( - dtype="bf16", - model=None, - qkv_format="bshd", - kernel_backend="FlashAttention", - cp_comm_type="p2p", - fp8_mha=False, +def generate_input_shapes( + qkv_format: str, + config: ModelConfig, + world_size: int, + kernel_backend: str, ): - """Test DotProductAttention module with context parallelism""" - - # args are passed as strings - fp8_mha = fp8_mha == "True" - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - if kernel_backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] - if kernel_backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] - - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" 
- if qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) - - print(f"[INFO] world_size:{world_size}, rank:{rank}") - - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # create flash attn comm group for CP - cp_comm_ranks = range(world_size) - assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert ( - world_size % 2 == 0 - ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!" - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - - if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) - - # instantiate core attn module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - ) - core_attn = core_attn.cuda() - - # create flash attn inputs if qkv_format == "bshd": q_input_shape = ( config.batch_size, @@ -191,34 +121,158 @@ def run_dpa_with_cp( cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: - assert False, f"{qkv_format} is an unsupported qkv_format!" + assert False, f"{qkv_format=} is not supported!" 
+ + return ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) + + +def get_tols(config, dtype): + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + atol = 2.5e-2 + rtol = 2.5e-2 + else: + atol = 3.5e-2 + rtol = 3.5e-2 + rmse_tol = 0.01 + elif dtype == "fp16": + atol = 5e-3 + rtol = 5e-3 + rmse_tol = 0.01 + elif dtype == "fp8": + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.1 + else: + assert False, f"{dtype=} is not supported!" + + return atol, rtol, rmse_tol + +def run_dpa_with_cp( + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_mha=False, + log_level=logging.WARNING, +): + """Test DotProductAttention module with context parallelism""" + logging.root.setLevel(log_level) + + # set up environment variables and config + fp8_mha = fp8_mha == "True" + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if kernel_backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + config = model_configs_flash_attn[model] + if kernel_backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + config = model_configs_fused_attn[model] + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type=} is not supported!" 
+ if qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" + + # set up distributed group + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + torch.cuda.set_device(device) + logging.info(f"[Rank {rank}] Setup: world_size {world_size}") + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # set up communication group for CP + cp_comm_ranks = range(world_size) + assert rank in cp_comm_ranks + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" + " = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + ).cuda() + if config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + 
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() + for x in [q, k, v]: + x.requires_grad = True + dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) + if fp8_mha: + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) - # create flash attention bias if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() else: bias = None - # run core_attn without CP - for x in [q, k, v]: - x.requires_grad = True - + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() - with fp8_context: out = core_attn( q, @@ -236,8 +290,30 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) + dq, dk, dv = q.grad, k.grad, v.grad + d_softmax_offset = None + if config.softmax_type != "vanilla": + d_softmax_offset = core_attn.softmax_offset.grad - # run core_attn wit CP + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + 
if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + # set up inputs q_, k_, v_, dout_, *rest = [ x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) ] @@ -267,8 +343,6 @@ def run_dpa_with_cp( ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -276,19 +350,8 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - - if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) - else: - fp8_context = nullcontext() + # run attention with fp8_context: out_ = core_attn( q_, @@ -306,18 +369,23 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - if fp8_mha: assert isinstance(out, Float8Tensor) assert isinstance(out_, Float8Tensor) out = out.dequantize() out_ = out_.dequantize() - for x in [out_, q_.grad, k_.grad, v_.grad]: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) - - # compare results with and without CP + # get outputs + dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + d_softmax_offset_ = None + if config.softmax_type != "vanilla": + d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + for x in [out_, dq_, dk_, dv_, d_softmax_offset_]: + if x is not None: + assert torch.all(~torch.isnan(x)) + assert torch.all(~torch.isinf(x)) + + ############ compare results between CP and no-CP ############ if 
qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -373,56 +441,70 @@ def run_dpa_with_cp( ).item() == 0 ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - elif dtype == "fp16": - tols = dict(atol=5e-3, rtol=5e-3) - elif dtype == "fp8": - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - else: - assert False, f"{dtype} is an unsupported dtype!" - - def _rmse(a, b): - return torch.sqrt((a - b).square().mean()).item() - - def _error(a, b): - if dtype != "fp8": - torch.testing.assert_close(a, b, **tols) - else: - try: - torch.testing.assert_close(a, b, **tols) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert ( - rmse < rmse_tol * rmse_range - ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - if qkv_format == "bshd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[:, 0], b[:, 0]) - _error(a[:, 1], b[:, 1]) - elif qkv_format == "sbhd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[0], b[0]) - _error(a[1], b[1]) - elif qkv_format == "thd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a, b) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] + names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i]: + if qkv_format == "bshd": + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + else: + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + # destroy distribution group dist.destroy_process_group() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 56bfa14234..a5c3457791 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. 
import logging -import math import os import sys import pathlib @@ -50,27 +49,35 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ( reset_rng_states, + compare_and_assert, ModelConfig, dtype_tols, get_available_attention_backends, ) -# Only run FP8 tests on H100 +# Check if hardware supports FP8 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() +# Reset RNG seed and states seed = 1234 -# Reset RNG states reset_rng_states() +# Reset FP8 global state manager @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield fp8.FP8GlobalStateManager.reset() +# Define F16 data types to test +param_types = [torch.float16] +if is_bf16_compatible(): + param_types.append(torch.bfloat16) +param_types_lean = [torch.bfloat16] + model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), "base_2_0": ModelConfig(2, 2048, 24, 128), @@ -86,12 +93,6 @@ def reset_global_fp8_state(): } -param_types = [torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher - param_types.append(torch.bfloat16) -param_types_lean = [torch.bfloat16] - - @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -125,12 +126,12 @@ def test_dot_product_attention( config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + # Get backends is_training = True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -141,7 +142,6 @@ def test_dot_product_attention( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - 
window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -227,6 +227,7 @@ def test_dot_product_attention( is_training, ) + # Compare results logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") @@ -259,23 +260,102 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_softmax = { + # test: ModelConfig(b, sq, hq, dqk) + "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), + "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"), + "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"), + "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "softmax_2_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), + "softmax_2_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"), + "softmax_3_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one" + ), + "softmax_3_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable" + ), + "softmax_4_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal" + ), + "softmax_4_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="off-by-one", + ), + "softmax_4_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="learnable", + ), + "softmax_5_0": ModelConfig( 
+ 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal" + ), + "softmax_5_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="off-by-one", + ), + "softmax_5_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="learnable", + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention( + dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False + ) + + model_configs_mla = { - # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 - "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 + # test: ModelConfig(b, sq, hq, dqk) + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), "mla_2_1": ModelConfig( 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 - ), # cross, 1 + ), "mla_2_2": ModelConfig( 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 - ), # cross, 1 - "mla_3_0": ModelConfig(8, 1, 16, 128, 
max_seqlen_kv=2048, head_dim_v=64), # inference - "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference + ), + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), } @@ -289,7 +369,7 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), @@ -344,18 +424,16 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), - "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped - "bias_1_5": ModelConfig( - 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" - ), # skipped + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), + "bias_1_5": 
ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"), "bias_2_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_1": ModelConfig( 2, 128, @@ -364,10 +442,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_2_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_3": ModelConfig( 2, 2048, @@ -376,13 +454,11 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped - "bias_2_4": ModelConfig( - 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), + "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"), "bias_2_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), "bias_3_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), @@ -400,14 +476,14 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), "bias_3_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_1": ModelConfig( 2, 128, @@ -416,10 +492,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" 
- ), # skipped + ), "bias_4_3": ModelConfig( 2, 2048, @@ -428,10 +504,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_4": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_5": ModelConfig( 2, 2048, @@ -440,7 +516,7 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi", - ), # skipped + ), } @@ -454,7 +530,7 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { - # test: b, h, hg, d, sq, skv, p, + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), @@ -492,7 +568,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "swa_1_1": ModelConfig(2, 2048, 16, 64), "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), @@ -532,7 +608,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { - # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type + # test: ModelConfig(b, sq, hq, dqk) "alibi_1_0": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" ), @@ -586,7 +662,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 128, 16, 64), "layout_0_1": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" 
@@ -634,7 +710,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), @@ -726,7 +802,6 @@ def _run_dot_product_attention( is_training: bool, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" - # Set RNG and environment varables reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -989,9 +1064,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tp_group=None, layer_number=1, attention_type=config.attn_type, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() + if is_training and config.softmax_type != "vanilla": + block.softmax_offset.requires_grad = True # Run a forward and backward pass if backend in ["FlashAttention", "UnfusedDotProductAttention"]: @@ -1026,12 +1104,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) if is_training: out.backward(d_out) - + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = block.softmax_offset.grad if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1060,18 +1140,18 @@ def 
get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) + return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) else: - return out_orig, (None, None, None) + return out_orig, (None, None, None, d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) model_configs_te_layer = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), "te_1_1": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -1436,6 +1516,7 @@ def _run_transformer_layer( model_configs_fp8_extra_state = { + # test: ModelConfig(b, sq, hq, dqk) "large": ModelConfig(2, 128, 4, 128, num_layers=1), } @@ -1445,7 +1526,8 @@ def _run_transformer_layer( @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): +def test_dpa_fp8_extra_state(model, dtype): + """Test DotProductAttention module in FP8 with checkpointing""" config = model_configs_fp8_extra_state[model] # Test backend availability is_training = True @@ -1459,9 +1541,9 @@ def test_sanity_attention_extra_state(model, dtype): if not fused_attn_supported and not flash_attn_supported: pytest.skip("No attention backend available.") - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( + outputs = _run_dpa_fp8_extra_state(dtype, config, 
checkpoint=False) + outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state( dtype, config, mimic_v1_6=True, checkpoint=True ) @@ -1483,7 +1565,8 @@ def test_sanity_attention_extra_state(model, dtype): ) -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): +def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + """Run DotProductAttention module in FP8 with checkpointing""" steps = 10 path = "checkpoint.pt" fp8_enabled = True @@ -1580,7 +1663,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), @@ -1600,33 +1683,6 @@ def get_model(dtype, config): qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] -def _rmse(a, b): - return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) - - -def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): - logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) - logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) - try: - if a.dtype != b.dtype: - a = a.to(b.dtype) - torch.testing.assert_close(a, b, atol=atol, rtol=rtol) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert rmse < rmse_tol * rmse_range, ( - name_a - + " vs " - + name_b - + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - ) - - @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) 
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @@ -1638,6 +1694,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): + """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] @@ -1691,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1699,8 +1756,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1708,12 +1766,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) if is_training: for i in range(len(param_names[:1])): logging.debug("========== {:^25s} ==========".format(param_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1721,10 +1780,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): + """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1851,6 +1912,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: 
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): + """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] # TODO(cyang): think of another way to verify dropout results @@ -1920,7 +1982,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1928,6 +1990,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout @@ -1935,7 +1998,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." 
else: - _error( + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1943,11 +2006,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1955,11 +2019,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): - + """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -2092,7 +2157,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_1": ModelConfig(1, 512, 1, 64), "fp8_2": ModelConfig(4, 512, 16, 64), "fp8_3": ModelConfig(1, 2048, 1, 128), @@ -2147,7 +2212,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 rmse_tol = 0.13 - _error( + compare_and_assert( fused_attn_fwd_fp8, unfused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2155,8 +2220,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_bwd_fp8, unfused_attn_bwd_f16, "fused_attn_bwd_fp8", @@ -2164,6 +2230,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 7078cb69de..c752d07d82 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -6,6 +6,7 @@ import subprocess import sys import pathlib +import logging 
import pytest import torch @@ -19,13 +20,15 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ModelConfig, get_available_attention_backends +pytest_logging_level = logging.getLevelName(logging.root.level) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA @@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") config = model_configs_flash_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + if not flash_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( @@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ), check=True, ) 
model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig( @@ -135,6 +150,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_4_0": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" + ), # GQA + "cp_4_1": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), # GQA + "cp_4_2": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), # GQA } @@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + if dtype == "fp8" and config.softmax_type != "vanilla": + pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip( + "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" 
+ ) + if config.softmax_type != "vanilla" and qkv_format == "thd": + pytest.skip( + "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + ) + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype], + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), - window_size=config.window_size, - context_parallel=True, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, fp8_mha=fp8_mha, + log_level=pytest_logging_level, ), check=True, ) diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index 288c5382e6..4dc3af411a 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=False, is_training=False, fp8=is_fp8, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 38f400f659..9e90f9fdad 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -20,6 +20,7 @@ get_attention_backend, AttentionParams, AttentionLogging, + check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend @@ -137,6 +138,31 @@ def reset_rng_states() -> None: torch.cuda.set_rng_state(cuda_rng_state) +def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if not is_fp8: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + return + + try: + if a.dtype != b.dtype: + a = a.to(b.dtype) + torch.testing.assert_close(a, b, 
atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = torch.sqrt((a - b).square().mean()).item() + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + class ModelConfig: def __init__( self, @@ -147,12 +173,15 @@ def __init__( max_seqlen_kv: int = None, num_gqa_groups: int = None, head_dim_v: int = None, + softmax_type: str = "vanilla", dropout_p: float = 0.0, attn_mask_type: str = "no_mask", attn_bias_type: str = "no_bias", alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + context_parallel: bool = False, + cp_comm_type: str = "p2p", total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -171,13 +200,16 @@ def __init__( self.kv_channels = (self.head_dim_qk, self.head_dim_v) self.hidden_size = self.num_heads * self.head_dim_qk self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.softmax_type = softmax_type self.dropout_p = dropout_p self.attn_mask_type = attn_mask_type self.attn_bias_type = attn_bias_type self.alibi_type = alibi_type self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape - self.window_size = window_size + self.window_size = check_set_window_size(self.attn_mask_type, window_size) + self.context_parallel = context_parallel + self.cp_comm_type = cp_comm_type self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -198,9 +230,7 @@ def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), pad_between_seqs: bool = False, - context_parallel: bool = False, 
deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -250,19 +280,21 @@ def test(): head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, - window_size=window_size, + window_size=config.window_size, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_requires_grad=core_attention_bias_requires_grad, pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p, - context_parallel=context_parallel, + context_parallel=config.context_parallel, + cp_comm_type=config.cp_comm_type, deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, inference_params=inference_params, + softmax_type=config.softmax_type, ) ( use_flash_attention, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 795697635d..77cd8d235a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { using namespace transformer_engine; 
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // TODO (cyang): add is_training to nvte_get_fused_attn_backend // sm90: fwd d<=256, bwd d=128 only // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || @@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { @@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { flag_m512 = true; } if ( @@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // check 64-bit ragged offset support (supported_ragged_offset_size) && // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 
91001) && + // softmax type + // pre-9.13.1: vanilla + // 9.13.1+: vanilla, off-by-one, learnable + (cudnn_runtime_version >= 91301 || + (cudnn_runtime_version < 91301 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); const Tensor *input_QKV = convertNVTETensorCheck(QKV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); 
Tensor *wkspace = convertNVTETensor(workspace); @@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); @@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con Tensor *input_output_dP = convertNVTETensorCheck(dP); Tensor *output_dQKV = convertNVTETensorCheck(dQKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); @@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == 
NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, - input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, + input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, + stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV 
void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked( const Tensor *input_Q = convertNVTETensorCheck(Q); const Tensor *input_KV = convertNVTETensorCheck(KV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = 
static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor *output_dQ = convertNVTETensorCheck(dQ); Tensor *output_dKV = convertNVTETensorCheck(dKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, 
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, + 
input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, + output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked( } // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_K = convertNVTETensorCheck(K); const Tensor *input_V = convertNVTETensorCheck(V); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = 
convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention 
with arbitrary sequence length. \n"); @@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dK = convertNVTETensorCheck(dK); Tensor *output_dV = convertNVTETensorCheck(dV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); @@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, 
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, - output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + 
deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4e6c3c858b..1d6435ad8a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 
+76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_q = is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } - const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { FADescriptor_v1 descriptor{b, h, @@ -122,6 +124,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, true, @@ -138,6 +141,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // O std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // page_table_k @@ -168,7 +172,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr Q, K, V, attn_scale, softmax_offset; std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, @@ -302,6 +306,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } + auto [O, Stats] = mha_graph->sdpa(Q, K, V, 
sdpa_options); std::vector o_stride(4); @@ -338,6 +351,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) @@ -358,17 +373,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, - page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_fprop_cache, descriptor); + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv, + page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. 
@@ -473,6 +489,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } + + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -483,14 +504,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, + void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, 
cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -506,6 +527,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -558,6 +580,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, @@ -579,6 +602,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // dV std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // offset_q @@ -608,7 +633,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_compute_data_type(fe::DataType_t::FLOAT); std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset, + seq_q, seq_kv; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -771,6 +797,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + 
sdpa_backward_options.set_dsink_token(d_softmax_offset); + } + auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); @@ -796,6 +837,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr> // dV key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_qo_tuple = @@ -814,17 +858,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_bprop_cache, descriptor); + auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset, + d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. 
Care should be taken to align each of the added worksapce tensors to their type. @@ -938,6 +982,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -949,8 +998,9 @@ using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -977,6 +1027,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -990,53 +1044,50 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( max_tokens = get_max_tokens(num_tokens); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor 
*output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1050,11 +1101,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, 
num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, + nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1074,9 +1125,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor 
*rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1122,6 +1174,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -1135,11 +1193,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1161,12 +1219,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, 
size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1192,6 +1250,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -1216,53 +1278,50 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + 
output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else 
if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1277,11 +1336,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - 
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1302,10 +1361,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor 
*cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1359,6 +1419,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1374,9 +1440,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1401,12 +1468,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, 
NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1425,6 +1493,10 @@ void fused_attn_arbitrary_seqlen_fwd( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1446,53 +1518,50 @@ void fused_attn_arbitrary_seqlen_fwd( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + 
output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else 
if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1507,11 +1576,11 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, 
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1532,13 +1601,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, 
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1577,6 +1647,12 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdV = output_dV->data.dptr; void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1592,9 +1668,10 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, 
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index e1a20274f4..b9658b0530 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -21,17 +21,19 @@ namespace transformer_engine { void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool 
deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, 
NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + 
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } 
// namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d7f0983763..995dbda7fb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, true, @@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, false, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 678b636910..0a0197423c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -107,6 +107,7 @@ struct FADescriptor_v1 { NVTE_QKV_Layout layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; @@ -116,14 +117,15 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, + bwd_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, - 
rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, + rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 44f5791490..a150978c4a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -124,6 +124,24 @@ enum NVTE_Mask_Type { NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5, }; +/*! \enum NVTE_Softmax_Type + * \brief Attention softmax types as described in + * Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3). + * For a given attention score S = Q*K^T, different softmax types perform different operations on S, + * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + * where alpha is a learnable parameter in shape [H]. + */ +enum NVTE_Softmax_Type { + /*! Vanilla softmax */ + NVTE_VANILLA_SOFTMAX = 0, + /*! Off-by-one softmax */ + NVTE_OFF_BY_ONE_SOFTMAX = 1, + /*! Learnable softmax */ + NVTE_LEARNABLE_SOFTMAX = 2, +}; + /*! \enum NVTE_Fused_Attn_Backend * \brief Fused attention backends */ @@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. * \param[in] dropout The dropout probability. 
* \param[in] num_attn_heads The number of heads in Q. * \param[in] num_gqa_groups The number of heads in K, V. @@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * @@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * e.g. M, ZInv, rng_state. * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, @@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. 
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); @@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] Q The Q tensor, in HD layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). 
* \param[in] deterministic Whether to execute with deterministic behaviours. @@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dQ The gradient of the Q tensor. * \param[out] dKV The gradient of the KV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. 
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. 
* @@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] K The K tensor. * \param[in] V The V tensor. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. 
* @@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. 
@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 67d21f6183..68b7aa8bbe 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -36,6 +36,10 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ + .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ + .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ + .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 40089dc2d6..9277569e11 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right) { + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); 
return backend; } @@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, 
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -262,10 +268,15 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + + auto dummy_softmax_offset_tensor = + 
TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; + auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -280,12 +291,12 @@ static void FusedAttnForwardImpl( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ 
-293,12 +304,13 @@ static void FusedAttnForwardImpl( auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -307,12 +319,13 @@ static void FusedAttnForwardImpl( auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), 
rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 min_num_segments = input_batch * max_segments_per_seq; } + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = @@ -453,37 +469,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, 
dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), - nullptr); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, - query_workspace_tensor.data(), nullptr); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, 
kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); @@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = 
std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index afa1bae633..4a60bd9fe1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -13,6 +13,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( SplitAlongDim, @@ -142,6 +143,7 @@ def __init__( attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -149,6 +151,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number + 
self.softmax_type = softmax_type def mask_func(x, y): return ( @@ -185,6 +188,7 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -326,7 +330,21 @@ def forward( dtype=query_layer.dtype ) - # attention scores and attention mask [b, np, sq, sk] + # add attention sink to the last column: [b, np, sq, sk+1] + if self.softmax_type != "vanilla": + matmul_result = torch.cat( + [ + matmul_result, + softmax_offset.to(dtype=matmul_result.dtype).expand( + matmul_result.size(0), -1, matmul_result.size(2), -1 + ), + ], + dim=-1, + ) + attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False) + attn_mask_type = "arbitrary" + + # attention scores and attention mask softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( matmul_result, attention_mask, attn_mask_type, softmax_scale @@ -337,6 +355,10 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) + # remove attention sink: [b, np, sq, sk] + if self.softmax_type != "vanilla": + attention_probs = attention_probs[..., :-1] + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
with self.attention_dropout_ctx(): @@ -917,6 +939,7 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, fused_attention_backend, @@ -925,6 +948,7 @@ def forward( fp8_meta, quantizers, deterministic, + softmax_offset, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype @@ -997,8 +1021,10 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) if is_output_fp8: out_ret = out_fp8 @@ -1059,8 +1085,10 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) out_save = out_ret fp8_tensors = (None, None, None, None) @@ -1114,6 +1142,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.softmax_type = softmax_type ctx.window_size = window_size ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] @@ -1224,6 +1253,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) @@ -1287,42 +1317,17 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - None, - None, - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) + d_bias = None + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + d_softmax_offset = None + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] return ( None, None, @@ -1336,7 +1341,8 @@ def backward(ctx, d_out): dq, dk, 
dv, - rest[0], + d_bias, + None, None, None, None, @@ -1351,6 +1357,7 @@ def backward(ctx, d_out): None, None, None, + d_softmax_offset, ) @@ -1390,6 +1397,7 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -1402,6 +1410,7 @@ def __init__( ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.softmax_type = softmax_type def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1453,6 +1462,7 @@ def forward( quantizers=None, pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1603,6 +1613,8 @@ def forward( fp8_meta=fp8_meta, quantizers=quantizers, pad_between_seqs=pad_between_seqs, + softmax_type=self.softmax_type, + softmax_offset=softmax_offset, ) else: with self.attention_dropout_ctx(): @@ -1626,6 +1638,7 @@ def forward( qkv_layout, core_attention_bias_type, attn_mask_type, + self.softmax_type, window_size, None, # rng_gen fused_attention_backend, @@ -1634,6 +1647,7 @@ def forward( fp8_meta, quantizers, self.deterministic, + softmax_offset, ) # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 09384217c6..2e4b6b6177 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -46,6 +46,7 @@ _cu_seqlens_info_with_cp_cache = {} _seq_chunk_ids_cache_for_reordering_before_attn = {} _seq_chunk_ids_cache_for_reordering_after_attn = {} +_softmax_offset_chunk_ids_cache = {} def flash_attn_p2p_communicate( @@ 
-318,6 +319,55 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def flash_attn_a2a_communicate_softmax_offset( + tensor: torch.Tensor, + h_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Split/AllGather communication for softmax offset.""" + if tensor is None: + return None + + global _softmax_offset_chunk_ids_cache + device = tensor.device + if (cp_size, device) not in _softmax_offset_chunk_ids_cache: + chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device) + _softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids + else: + chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)] + + if before_attn: + # softmax_offset: split round-robin to CP ranks + # [1, h, 1, 1] -> [1, cp, h//cp, 1, 1] + shape = tensor.shape + tensor = tensor.view( + *shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :] + ) + rank = get_distributed_rank(cp_group) + output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank]) + output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :]) + else: + # d_softmax_offset: all-gather from all ranks to all ranks + # [1, h//cp, 1, 1] -> [1, h, 1, 1] + inp = tensor.view(-1) + output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device) + with torch.cuda.stream(cp_stream): + torch.distributed.all_gather_into_tensor( + output, + inp, + group=cp_group, + async_op=False, + ) + torch.cuda.current_stream().wait_stream(cp_stream) + output = output.view( + *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :] + ) + return output + + def _get_cu_seqlens_info_with_cp( batch_size: int, max_seqlen: int, @@ -1854,7 +1904,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = 
fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2014,7 +2064,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2171,7 +2221,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2289,7 +2339,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -3122,7 +3172,7 @@ def backward(ctx, dout): dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] - dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, @@ -3283,6 +3333,8 @@ def forward( cp_stream, quantizers, use_flash_attn_3, + softmax_type, + softmax_offset, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3391,6 +3443,10 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) + if softmax_type != "vanilla": + softmax_offset = 
flash_attn_a2a_communicate_softmax_offset( + softmax_offset, 1, cp_size, cp_group, cp_stream, True + ) if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v @@ -3430,6 +3486,8 @@ def forward( cu_seqlens_kv_padded=cu_seqlens_kv_padded, window_size=window_size, **fp8_meta_kwargs, + softmax_type=softmax_type, + softmax_offset=softmax_offset, ) if fp8: out = out._data @@ -3532,6 +3590,7 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.softmax_type = softmax_type ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -3695,7 +3754,7 @@ def backward(ctx, dout): dout_part, fake_dtype=dout_dtype, internal=True ) - dq, dk, dv, _ = fused_attn_bwd( + dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -3719,6 +3778,7 @@ def backward(ctx, dout): window_size=ctx.window_size, deterministic=ctx.deterministic, **fp8_meta_kwargs, + softmax_type=ctx.softmax_type, ) if ctx.fp8: dq = dq._data @@ -3763,6 +3823,17 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + d_bias = None + d_softmax_offset = None + if ctx.use_fused_attention: + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] + d_softmax_offset = flash_attn_a2a_communicate_softmax_offset( + d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.fp8: dq = ctx.dQKV_quantizer.create_tensor_from_data( dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 @@ -3793,6 +3864,7 @@ def backward(ctx, dout): None, None, None, + d_bias, None, None, None, @@ -3803,6 +3875,7 @@ def backward(ctx, dout): None, None, None, + d_softmax_offset, ) @@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp( quantizers=None, pad_between_seqs=False, use_flash_attn_3=False, + 
softmax_type="vanilla", + softmax_offset=None, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp( else: assert isinstance( cp_group, dist_group_type - ), f"Unsupported process group for CP communication type {cp_comm_type}!" + ), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!" assert qkv_format in [ "bshd", "sbhd", "thd", - ], f"QKV format of {qkv_format} is not supported with context parallelism!" + ], f"Context parallelism does not support {qkv_format=}!" assert ( qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" + ), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!" assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( - """Attention bias is only supported with FusedAttention and "causal" """ - """or "no_mask" mask types!""" + "Context parallelism only supports attention bias with FusedAttention backend and" + " non-padding mask types!" ) assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + ), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) @@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "The context parallel running configs cannot support sliding window attetnion!" + ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "The context parallel running configs cannot support MLA!" 
+ ], "Context parallelism does not support MLA with {cp_comm_type=}!" + + if fp8 and fp8_meta is not None: + if fp8_meta["recipe"].fp8_dpa: + assert ( + softmax_type == "vanilla" + ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + assert ( + softmax_type == "vanilla" or use_fused_attention + ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + assert ( + softmax_type == "vanilla" or cp_comm_type == "a2a" + ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + assert ( + softmax_type == "vanilla" or qkv_format != "thd" + ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" args = [ is_training, @@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + args += [ + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + quantizers, + use_flash_attn_3, + softmax_type, + softmax_offset, + ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b35b87a83f..f72cd69262 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -11,6 +11,7 @@ import logging import torch +from torch.nn.parameter import Parameter import transformer_engine_torch as tex from transformer_engine.pytorch.utils import get_cudnn_version @@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule): softmax_scale: Optional[float], default = `None` softmax scale for 
the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -223,6 +235,7 @@ def __init__( cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -307,6 +320,20 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.softmax_type = softmax_type + if self.softmax_type == "vanilla": + self.softmax_offset = None + if self.softmax_type == "off-by-one": + self.softmax_offset = torch.zeros( + self.num_attention_heads // self.tp_size, device="cuda" + ) + if self.softmax_type == "learnable": + self.register_parameter( + "softmax_offset", + Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + get_rng_state_tracker=get_rng_state_tracker, + ) + attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, @@ -328,6 +355,7 @@ def __init__( layer_number=layer_number, deterministic=self.deterministic, **attn_kwargs, + softmax_type=self.softmax_type, ) self.unfused_attention = UnfusedDotProductAttention( @@ -335,6 +363,7 @@ def __init__( attention_type=attention_type, 
**attn_kwargs, layer_number=layer_number, + softmax_type=self.softmax_type, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -634,6 +663,7 @@ def forward( query_layer, num_gemms=3, allow_non_contiguous=True, + allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): @@ -922,6 +952,7 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: if qkv_format == "thd": pad_between_seqs = ( @@ -957,11 +988,13 @@ def forward( pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, + cp_comm_type=self.cp_comm_type, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, inference_params=inference_params, + softmax_type=self.softmax_type, ) global _attention_backends if is_in_onnx_export_mode(): @@ -1022,6 +1055,12 @@ def forward( ) # run attention + softmax_offset = ( + self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) + if self.softmax_offset is not None + else None + ) + if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = dpa_utils.get_alibi( @@ -1071,7 +1110,6 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) - # checkpoint_core_attention=False if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, @@ -1101,6 +1139,7 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, ) return self.fused_attention( query_layer, @@ -1129,6 +1168,7 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, 
inference_params=inference_params, + softmax_offset=softmax_offset, ) from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled @@ -1157,6 +1197,7 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, ) return self.unfused_attention( _alibi_cache, @@ -1173,5 +1214,6 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9b2b9a1ac3..72c595e3ff 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -24,6 +24,7 @@ QKVLayout, AttnBiasType, AttnMaskType, + SoftmaxType, FusedAttnBackend, META_QKV, META_DQKV, @@ -206,6 +207,8 @@ class AttentionParams: Attention dropout. context_parallel: bool, default = `False` Whether context parallelism is used or not. + cp_comm_type: str, default = "p2p" + The communication type of context parallelism. deterministic: bool, default = `False` Whether to run `DotProductAttention` with determinism or not. is_training: bool, default = `True` @@ -216,6 +219,8 @@ class AttentionParams: The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. + softmax_type: str, default = "vanilla" + The type of softmax operation. See DotProductAttention for details. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -237,11 +242,13 @@ class AttentionParams: pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False + cp_comm_type: str = "p2p" deterministic: bool = False is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None + softmax_type: str = "vanilla" def __eq__(self, other): """ @@ -308,11 +315,13 @@ def get_attention_backend( pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel + cp_comm_type = attention_params.cp_comm_type deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params + softmax_type = attention_params.softmax_type # Run config logger = logging.getLogger("DotProductAttention") @@ -565,6 +574,51 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention 3 for dropout") use_flash_attention_3 = False + # Filter: Softmax type + # context_parallel | softmax_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # no | vanilla | All + # no | off-by-one | FusedAttention, UnfusedDotProductAttention + # no | learnable | FusedAttention, UnfusedDotProductAttention + # yes | vanilla | FusedAttention, FlashAttention + # yes | off-by-one | FusedAttention + # yes | learnable | FusedAttention + if softmax_type != "vanilla": + logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) + use_flash_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) + use_fused_attention = False + logger.debug( + 
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + ) + use_unfused_attention = False + if qkv_format == "thd": + logger.debug( + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + ) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", + softmax_type, + ) + use_unfused_attention = False + if context_parallel: + logger.debug( + "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" + " = %s", + softmax_type, + ) + use_unfused_attention = False + if cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -806,6 +860,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], attention_dropout, num_heads, num_gqa_groups, diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 5fd16bf1a1..790d78c75e 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module): For that, please use `get_qkv_layout` to gain the layout information. name: str, default = `None` name of the module, currently used for debugging purposes. 
+ softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -245,6 +256,7 @@ def __init__( qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -262,6 +274,7 @@ def __init__( self.return_bias = return_bias self.cp_size = 1 self.cp_rank = 0 + self.softmax_type = softmax_type kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -416,6 +429,7 @@ def __init__( tp_group=tp_group, layer_number=self.layer_number, attention_type=self.attention_type, + softmax_type=self.softmax_type, ) # Linear diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b9810bf861..df2f5d1cab 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -12,6 +12,7 @@ NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, + NVTE_Softmax_Type, NVTE_Fused_Attn_Backend, ) from ..tensor.quantized_tensor import Quantizer @@ -86,6 +87,12 @@ "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, } +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + 
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} + FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, @@ -131,8 +138,10 @@ def fused_attn_fwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, + softmax_offset: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -197,6 +206,8 @@ def fused_attn_fwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -205,6 +216,9 @@ def fused_attn_fwd( rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + softmax_offset: torch.Tensor, default = None + softmax offset tensor in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. 
Returns ---------- @@ -286,6 +300,7 @@ def fused_attn_fwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, cu_seqlens_q, cu_seqlens_kv, @@ -300,6 +315,7 @@ def fused_attn_fwd( s_quantizer, o_quantizer, attn_bias, + softmax_offset, rng_gen, rng_elts_per_thread, ) @@ -333,6 +349,7 @@ def fused_attn_bwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -398,6 +415,8 @@ def fused_attn_bwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -417,6 +436,9 @@ def fused_attn_bwd( d_bias: torch.Tensor, optional gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias + d_softmax_offset: torch.Tensor, optional + gradient tensor of softmax offset in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. 
""" if attn_scale is None: d = q.size(-1) @@ -454,6 +476,7 @@ def fused_attn_bwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, deterministic, cu_seqlens_q, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4cb05725bc..4edc6d81e1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -73,28 +73,31 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional 
cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread); + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d835a5c94..8179727e58 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -58,13 +58,14 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t 
head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } @@ -72,14 +73,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread) { + const 
std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread) { TensorWrapper te_Q, te_K, te_V, te_O, te_S; auto none = py::none(); @@ -181,6 +183,16 @@ std::vector fused_attn_fwd( DType::kInt32, nullptr, nullptr, nullptr); } + // softmax offset + TensorWrapper te_SoftmaxOffset; + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); + std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + te_SoftmaxOffset = + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -199,11 +211,11 @@ std::vector fused_attn_fwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -215,51 +227,52 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(o_python); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - // allocate memory for 
nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = - (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) - : rng_state; - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } + auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); + }; + // allocate memory for nvte_aux_tensor_pack.tensors + // f16_max512 : S [b, h, sq, skv] + // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + size_t i = 0; + at::Tensor output_tensor; + // intermediate softmax tensor, S or M + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + 
static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + // fp8 has an additional softmax stats tensor, ZInv + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + } + // rng_state + if (i < nvte_aux_tensor_pack.size) { + set_tensor_param(i++, rng_state); + } + // bias (optional) + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + set_tensor_param(i++, Bias.value()); + } + // softmax_offset (optional) + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + set_tensor_param(i++, SoftmaxOffset.value()); } // execute the kernel NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -274,9 +287,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool 
deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -499,6 +513,15 @@ std::vector fused_attn_bwd( } } + // create dSoftmaxOffset in the same shape as SoftmaxOffset + at::Tensor dSoftmaxOffset; + TensorWrapper te_dSoftmaxOffset; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA); + dSoftmaxOffset = torch::empty({1, static_cast(h_q), 1, 1}, options); + te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset); + } + // create workspace TensorWrapper workspace; @@ -507,10 +530,10 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, 
window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -523,16 +546,16 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {py_dQ, py_dK, py_dV, py::cast(dBias)}; + return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)}; } at::Tensor fa_prepare_fwd(at::Tensor qkvi) { diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0f2e3c4de1..70366dabe5 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -966,12 +966,13 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. 
" - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if not self.allow_different_data_and_param_types: + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.activation_dtype = dtype def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1060,6 +1061,7 @@ def prepare_forward( inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know @@ -1067,6 +1069,7 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + self.allow_different_data_and_param_types = allow_different_data_and_param_types self.forwarded_at_least_once = True # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 89e43f845c..8a032b2f55 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module): and `DotProductAttention` modules. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. 
+ For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -306,6 +317,7 @@ def __init__( qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, qk_norm_before_rope: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -362,6 +374,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.attn_input_format = attn_input_format + self.softmax_type = softmax_type self.name = name @@ -397,6 +410,7 @@ def __init__( "qkv_format": self.attn_input_format, "seq_length": seq_length, "micro_batch_size": micro_batch_size, + "softmax_type": self.softmax_type, } self.self_attention = MultiheadAttention( From 2db20a6f8218ee9c04044b5596a71ae4154d68d3 Mon Sep 17 00:00:00 2001 From: shengfangd Date: Tue, 23 Sep 2025 09:00:34 +0800 Subject: [PATCH 31/78] [QA] Add pytest xml report for all tests in qa folder that use pytest (#2169) * Add pytest xml report for debug unittest and onnx unittest, and remove the duplicated test line in qa/L0_pytorch_debug_unittest/test.sh --------- Signed-off-by: erindai --- qa/L0_pytorch_debug_unittest/test.sh | 19 ++++++++++--------- qa/L1_pytorch_distributed_unittest/test.sh | 4 ++-- qa/L1_pytorch_onnx_unittest/test.sh | 4 +++- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index b4bf0a0246..7f19dda670 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -7,6 +7,8 @@ : ${TE_PATH:=/opt/transformerengine} 
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" # Config with the dummy feature which prevents nvinspect from being disabled. # Nvinspect will be disabled if no feature is active. @@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest 
-v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL diff --git 
a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7f061d222a..19889946a6 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -47,9 +47,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 1486d50971..720aa79e25 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -7,5 +7,7 @@ pip3 install onnxruntime==1.20.1 pip3 install onnxruntime_extensions==0.13.0 : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m 
pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py From a92a0ad294a750e9c3d26dc9677746daa94da8ee Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 23 Sep 2025 11:15:06 -0400 Subject: [PATCH 32/78] [JAX] Local-Amax for Current-Scaling (#2183) * Adding Amax Primitive and related args. Signed-off-by: Ming Huang * Enable local-amax for current-scaling and optionally run AR aross FSDP/TP/SP. Signed-off-by: Ming Huang * Adding doc for Amax Primitive. Signed-off-by: Ming Huang * Fix the function name conflict. Signed-off-by: Ming Huang * Modification as feedback suggested. Signed-off-by: Ming Huang * Fix errors from lint. Signed-off-by: Ming Huang * Fix the wrong amax-scope in the bwd. Signed-off-by: Ming Huang * Added more description for amax-scope Signed-off-by: Ming Huang * Fix the wrong attribute name. Signed-off-by: Ming Huang * Keep dim for AmaxCalcuation. Signed-off-by: Ming Huang * Remove keepDim and add shardy_rule Signed-off-by: Ming Huang * Fix shardy_rule Signed-off-by: Ming Huang * Remove extra-collective bytes from ref_coll_count due to local amax. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Signed-off-by: Ming-Xu Huang Co-authored-by: Phuong Nguyen --- tests/jax/test_distributed_layernorm.py | 2 - .../jax/cpp_extensions/activation.py | 12 +- transformer_engine/jax/cpp_extensions/base.py | 17 ++- .../jax/cpp_extensions/normalization.py | 41 ++++- .../jax/cpp_extensions/quantization.py | 142 +++++++++++++++++- transformer_engine/jax/dense.py | 14 +- transformer_engine/jax/layernorm_mlp.py | 8 +- 7 files changed, 213 insertions(+), 23 deletions(-) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index a777e2f432..f3296277c8 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -76,8 +76,6 @@ def generate_collectives_count_ref( all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize ) other_bytes = 0 - if fp8_recipe == recipe.Float8CurrentScaling(): - allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction return generate_collectives_count( allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d0a4e58fb6..9499b16246 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -26,7 +26,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl +from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -979,6 +979,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) 
-> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -987,6 +988,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: If quantizer is None: @@ -1044,7 +1046,13 @@ def act_lu( activation_type=activation_type, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out if isinstance(quantizer, DelayedScaleQuantizer): diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index cc8a07860a..96b73909e1 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -173,7 +173,7 @@ def shardy_sharding_rule(*args): _primitive_registry = {} -def register_primitive(cls): +def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. 
""" @@ -186,13 +186,14 @@ def register_primitive(cls): def name_of_wrapper_p(): return cls.name + "_wrapper" - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="cuda") - cls.inner_primitive = inner_p + if not outer_only: + inner_p = core.Primitive(cls.name) + dispatch.prim_requires_devices_during_lowering.add(inner_p) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform="cuda") + cls.inner_primitive = inner_p outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 351767e367..d265be398c 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -27,7 +27,7 @@ NamedSharding, get_cudnn_version, ) -from .quantization import _quantize_dbias_impl +from .quantization import _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -880,6 +880,7 @@ def layernorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -893,6 +894,7 @@ def layernorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. 
quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -952,7 +954,13 @@ def layernorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out, mu, rsigma is_2x2x = quantizer.is_2x2x() @@ -1082,6 +1090,7 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1093,6 +1102,7 @@ def rmsnorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1153,7 +1163,11 @@ def rmsnorm_fwd( quantizer=None, ) out, _ = _quantize_dbias_impl( - out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + out.data, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, ) return out, rsigma @@ -1278,6 +1292,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ): """Common wrapper for normalization forward pass. @@ -1294,6 +1309,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. 
This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1311,12 +1327,27 @@ def normalization_fwd( zero_centered_gamma is not supported if norm_type is 'rmsnorm'. """ if norm_type == "layernorm": - output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = layernorm_fwd( + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) elif norm_type == "rmsnorm": assert ( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) mu = None else: raise ValueError(f"{norm_type=} is not supported.") diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 895913d0ac..98b9b7e785 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,6 +6,8 @@ from functools import reduce from typing import Tuple, Optional, Union import math +from enum import Enum + import jax import jax.numpy as jnp @@ -26,7 +28,12 @@ get_min_device_compute_capability, NamedSharding, ) -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp, + global_mesh_resource, + lax_paral_op, +) from ..quantize import ( ScaledTensor2x, ScaledTensor, @@ -526,6 +533,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = (1,) # amax_scope + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + ): + """ + amax calcuation abstract + """ + del amax_scope + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + ): + """ + amax calcuation implementation + """ + del amax_scope + amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) + return amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del (amax_scope, arg_infos, result_infos) # Unused. 
+ amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding + + @staticmethod + def partition( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + + def sharded_impl(x): + amax = AmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + ) + if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh) + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + + if amax_scope is AmaxScope.FSDP: # Run AR across FSDP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + + return amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, amax_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(amax_scope, mesh, value_types, result_types): + """ + amax calcuation shardy_sharding_rule + """ + del amax_scope, mesh, result_types + prefix = "AmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_spec = (f"{prefix}_amax",) + return SdyShardingRule((input_spec,), (output_spec,)) + + +register_primitive(AmaxCalculationPrimitive, outer_only=True) + + def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): @@ -572,6 +699,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -628,7 +756,10 @@ def _quantize_dbias_impl( # until the tensor is dequantized (e.g. in the GEMM). 
amax = x.amax if amax is None: - amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) + amax = AmaxCalculationPrimitive.outer_primitive.bind( + x.data, + amax_scope=amax_scope, + ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale @@ -700,6 +831,7 @@ def quantize( x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -710,6 +842,7 @@ def quantize( flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. is None. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A ScaledTensor containing the quantized input tensor. @@ -718,6 +851,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) return out @@ -727,6 +861,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -737,6 +872,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. 
+ Returns: A tuple containing: @@ -750,6 +887,7 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8087159a3a..dd7f5e0e84 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -15,6 +15,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, @@ -64,6 +65,7 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, + using_global_amax_of_x: bool = False, ): """Perform dense layer transformation with optional quantization. @@ -77,6 +79,7 @@ def dense( bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract quantizer_set: QuantizerSet which contains quantizers for different tensor types + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. Returns: Transformed output tensor @@ -93,6 +96,7 @@ def dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ) return output @@ -103,6 +107,7 @@ def dense( 3, 4, 5, + 7, ), ) def _dense( @@ -113,6 +118,7 @@ def _dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ): """Internal implementation of dense layer transformation with custom VJP. @@ -127,6 +133,7 @@ def _dense( input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. 
Returns: Transformed output tensor @@ -139,6 +146,7 @@ def _dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ) return output @@ -151,6 +159,7 @@ def _dense_fwd_rule( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ): """Forward pass rule for dense layer transformation. @@ -175,6 +184,7 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -182,6 +192,7 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -212,7 +223,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. @@ -238,6 +249,7 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP, ) # GEMM NT diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801af..e3eaa53e1d 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . 
import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -272,13 +273,12 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + amax_scope=AmaxScope.TPSP, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, - flatten_axis=-2, - quantizer=ffn1_quantizer_set.kernel, + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP ) # NN GEMM @@ -317,6 +317,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) # NN GEMM @@ -417,6 +418,7 @@ def _layernorm_mlp_bwd_rule( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From 3f875fb57fcf2872d238f8c7cb199b171c424536 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Sep 2025 15:10:46 -0400 Subject: [PATCH 33/78] [JAX] Restore Shardy Rule with CompoundFactor (#2167) * Rework shardy rules * WAR for compound factor=1 Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 34 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 11 +- .../jax/cpp_extensions/normalization.py | 5 +- .../jax/cpp_extensions/quantization.py | 5 +- .../jax/quantize/scaling_modes.py | 106 ++++++++++-------- 5 files changed, 90 insertions(+), 71 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 9499b16246..a8c14a6087 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -410,27 +410,28 @@ def shardy_sharding_rule( result_types, ): 
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) + prefix = "ActLu_" + input_shape = value_types[0].shape + output_shape = input_shape[:-2] + input_shape[-1:] + # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 + output_shape, unique_var=prefix + "x", flatten_axis=-1 ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) - out = (*x_axes[:-2], x_axes[-1]) - scale_inv = scale_rules.rowwise_rule + x_axes = scale_rules.input_spec + # Correct input spec with act dim + x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] + out = scale_rules.input_spec colwise_out = (prefix + "out_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple( - multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) - ) + colwise_out = multidim_transpose(out, transpose_axis=-1) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule - # amax is always a unit tensor. 
amax = (prefix + "amax",) return SdyShardingRule( @@ -438,7 +439,8 @@ def shardy_sharding_rule( x_axes, ("…1",), ), - (out, colwise_out, scale_inv, colwise_scale_inv, amax), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), + **scale_rules.factor_sizes, ) @@ -883,26 +885,30 @@ def shardy_sharding_rule( result_types, ): del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" + prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) return SdyShardingRule( (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2acc3fb68c..118000be7a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -712,7 +712,7 @@ def shardy_sharding_rule( del out_dtype, grad, use_split_accumulator del mesh, result_types - prefix = "GemmPrimitive_" + prefix = "Gemm_" warnings.warn( "Known issues with TE GemmPrimitives when Shardy propagation is enabled. 
For now," @@ -746,13 +746,8 @@ def _generate_operand_rules(name, ndim, cdims): lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) if scaling_mode.is_1d_block_scaling(): - # Shardy rules for MXFP8 scales cannot be related to the operands because of the - # global-unpadding and local-padding workflow. This can potentially insert expensive - # re-shards in the partition call later if the scales are not already sharded correctly. - lhs_scale_specs, rhs_scale_specs = map( - lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), - (lhs_specs, rhs_specs), - ) + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d265be398c..3348c725be 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -581,9 +581,9 @@ def shardy_sharding_rule( result_types, ) - prefix = "NormFwdPrimitive_" + prefix = "NormFwd_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec @@ -604,6 +604,7 @@ def shardy_sharding_rule( mu, rsigma, ), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 98b9b7e785..021af4c9db 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -495,9 +495,9 @@ def shardy_sharding_rule( ): del out_dtype, scale_dtype, is_outer, mesh, result_types - prefix = "BaseDBiasQuantizePrimitive_" + prefix = 
"DBiasQuantize_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), + value_types[0].shape, unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -519,6 +519,7 @@ def shardy_sharding_rule( return SdyShardingRule( (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e81a614f0e..b7828e9315 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,7 +17,7 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import BATCHING +from jax.experimental.custom_partitioning import BATCHING, CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -152,12 +152,15 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: @abstractmethod def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -232,12 +235,15 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. 
Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -245,7 +251,7 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -323,20 +329,23 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -562,52 +571,55 @@ def get_grouped_scale_shape( return (n_block_x * n_block_y,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. 
Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = [f"{unique_var}{i}" for i in range(input_rank)] - rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] - colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] - - # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. - # Unfortunately, because Shardy rules are applied to the inner primitive, the - # only way to preserve the relationship is to lower unpadded scales to the - # underlying custom call and pad them in C++. Until that's implemented, the - # Shardy rules for block scales have to be completely disconnected from the - # Shardy rules for the tensor they belong to. - - # # We have to use two different factors in the two CompoundFactors because of Shardy - # # verifier requirements, even though they are the same. - # rowwise_var = unique_var - # colwise_var = f"{unique_var}_" - # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # # The rowwise and colwise scale tensors should be sharded the same way as the input. - # # However, we need to adjust the dimensions where the block scaling factor applies. - # rowwise = input_spec.copy() - # rowwise[-1] = rowwise_var - - # colwise = input_spec.copy() - # colwise[flatten_axis - 1] = colwise_var - - # # This implementation needs to be updated for different block dims. 
- # assert self._block_dims == (1, 32) + input_rank = len(input_shape) + input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] + flatten_axis = (flatten_axis + input_rank) % input_rank + + # This implementation needs to be updated for different block dims. + assert self._block_dims == (1, 32) + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. + blocksizes = {} + colwise_var = f"{unique_var}_None" + rowwise_var = f"{unique_var}_None" + if not input_shape[-1] == 32: + rowwise_var = input_spec[-1] + "_compound" + input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") + blocksizes["blocksize_x"] = 32 + if not input_shape[flatten_axis - 1] == 32: + colwise_var = input_spec[flatten_axis - 1] + "_compound" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") + blocksizes["blocksize_y"] = 32 + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + rowwise = input_spec.copy() + rowwise[-1] = rowwise_var + + colwise = input_spec.copy() + colwise[flatten_axis - 1] = colwise_var return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, + blocksizes, ) @@ -697,18 +709,22 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return self._get_impl().get_quantize_layout(usage) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis=-1 + self, + input_shape, + unique_var, + flatten_axis=-1, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. 
Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 From afd15a16891fdc5d0f3efeb21e44ab15b54634c2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Sep 2025 15:52:54 -0400 Subject: [PATCH 34/78] [JAX] Update JAX version requirement in pyproject.toml (#2197) update jax requirements Signed-off-by: Phuong Nguyen --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ef112d2798..64ff4c5cea 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,8 @@ # See LICENSE for license information. 
[build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", +"torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" From 7933781da84f26bd0029467ee6c1e68aa18ce47e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 26 Sep 2025 01:03:45 -0700 Subject: [PATCH 35/78] temp fix to enable --overlap-grad-reduce Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9966c78f8..94d83fd638 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -459,7 +459,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad): list(weight.main_grad.shape), weight.dtype, ) - elif ctx.fuse_wgrad_accumulation: + # TODO: Need to check why weight doesn't have attr grad_added_to_main_grad when fine_grained_activation_offloading is True. 
+ elif ctx.fuse_wgrad_accumulation and not ctx.fine_grained_activation_offloading: wgrad = None else: wgrad = None From 9e727966f4505d6740372572f89facd9d01f4c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 26 Sep 2025 20:26:29 +0200 Subject: [PATCH 36/78] [PyTorch] Unpin version of onnxscript and onnxruntime (#2202) * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- build_tools/pytorch.py | 2 +- qa/L1_pytorch_onnx_unittest/test.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 33a3abfb7e..a974e370d7 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,7 +14,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] + return ["torch>=2.1", "einops", "onnxscript", "onnx"] def test_requirements() -> List[str]: diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 720aa79e25..7fce13a3dc 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -3,8 +3,8 @@ # See LICENSE for license information. 
-pip3 install onnxruntime==1.20.1 -pip3 install onnxruntime_extensions==0.13.0 +pip3 install onnxruntime +pip3 install onnxruntime_extensions : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} From 4d1457865847a83bb3b4582149188160fedddf98 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 26 Sep 2025 22:39:21 -0400 Subject: [PATCH 37/78] [JAX] Fix XML filename in the L0_jax_uniitest (#2205) fix xml file name Signed-off-by: Phuong Nguyen --- qa/L0_jax_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index e4a3f4630e..cb097d492a 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" From d75bf43f2e6fdc01afdf96a91b09245dc3c4987f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 27 Sep 2025 12:45:24 -0400 Subject: [PATCH 38/78] [JAX] CollectiveGemm (#2166) * init cgemm + unit tests * UB bootstrap with NCCL, no MPI dependency * add NVLINK-P2P check + error 
message * skip tests if no NVLINK available * use std::vector to store ncclComm_t * update misuse of TP warning Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 1 + examples/jax/collective_gemm/common.py | 245 ++++++++++ examples/jax/collective_gemm/conftest.py | 29 ++ .../jax/collective_gemm/run_test_cgemm.sh | 111 +++++ .../jax/collective_gemm/test_dense_grad.py | 214 ++++++++ examples/jax/collective_gemm/test_gemm.py | 206 ++++++++ .../test_layernorm_mlp_grad.py | 272 +++++++++++ qa/L0_jax_distributed_unittest/test.sh | 4 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 98 +++- .../userbuffers/userbuffers-host.cpp | 30 +- .../transformer_engine/comm_gemm_overlap.h | 25 + transformer_engine/common/util/logging.h | 10 + transformer_engine/jax/cpp_extensions/gemm.py | 458 ++++++++++++++++-- transformer_engine/jax/cpp_extensions/misc.py | 8 + transformer_engine/jax/csrc/extensions.h | 9 +- .../jax/csrc/extensions/cgemm_helper.cpp | 259 ++++++++++ .../jax/csrc/extensions/cgemm_helper.h | 189 ++++++++ .../jax/csrc/extensions/gemm.cpp | 140 +++++- transformer_engine/jax/csrc/extensions/misc.h | 26 + .../jax/csrc/extensions/pybind.cpp | 12 +- transformer_engine/jax/dense.py | 73 ++- transformer_engine/jax/flax/transformer.py | 1 + transformer_engine/jax/layernorm_mlp.py | 43 +- transformer_engine/jax/sharding.py | 19 + 24 files changed, 2385 insertions(+), 97 deletions(-) create mode 100644 examples/jax/collective_gemm/common.py create mode 100644 examples/jax/collective_gemm/conftest.py create mode 100644 examples/jax/collective_gemm/run_test_cgemm.sh create mode 100644 examples/jax/collective_gemm/test_dense_grad.py create mode 100644 examples/jax/collective_gemm/test_gemm.py create mode 100644 examples/jax/collective_gemm/test_layernorm_mlp_grad.py create mode 100644 transformer_engine/jax/csrc/extensions/cgemm_helper.cpp create mode 100644 transformer_engine/jax/csrc/extensions/cgemm_helper.h diff --git a/build_tools/jax.py 
b/build_tools/jax.py index 67efbf00fd..1f9552eb69 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -87,4 +87,5 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, + libraries=["nccl"], ) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py new file mode 100644 index 0000000000..da79b21377 --- /dev/null +++ b/examples/jax/collective_gemm/common.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Shared functions for the comm_overlap tests""" + +import jax.numpy as jnp +import numpy as np + + +# Add this after your existing imports +def dtype_tols(dtype, rtol=None, atol=None): + """Expected numerical tolerance for a data type.""" + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + + # Default tolerances for common dtypes + if dtype in [jnp.float32, "float32"]: + return {"rtol": 1e-5, "atol": 1e-8} + elif dtype in [jnp.float16, "float16"]: + return {"rtol": 1e-3, "atol": 1e-6} + elif dtype in [jnp.bfloat16, "bfloat16"]: + return {"rtol": 1e-2, "atol": 1e-5} + else: + return {"rtol": 1e-5, "atol": 1e-8} + + +def assert_allclose( + actual, + desired, + rtol=None, + atol=None, + dtype=None, + **kwargs, +): + """Check if two tensors are close.""" + # Infer data type if needed + if dtype is None: + if isinstance(actual, float): + dtype = "float32" + else: + dtype = actual.dtype + + # Determine tolerances + tols = {} + if rtol is None or atol is None: + tols = dtype_tols(dtype) + if rtol is not None: + tols["rtol"] = rtol + if atol is not None: + tols["atol"] = atol + + # Cast tensors to fp32 + if not isinstance(actual, float): + actual = actual.astype(jnp.float32) + if not isinstance(desired, float): + desired = desired.astype(jnp.float32) + + 
# Check if tensors are close + np.testing.assert_allclose(actual, desired, **tols, **kwargs) + + +def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): + if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): + diff = jnp.abs(ref_output - gathered_output) + mask = diff > (atol + rtol * jnp.abs(gathered_output)) + print(mask.astype(int)) + print(jnp.where(mask, diff, 0)) + + +# Shared constants for all tests +DP_AXIS = "data" +TPSP_AXIS = "tensor_sequence" +PARAMS_KEY = "params" + +# Shared functions for distributed testing +import argparse +import jax +from jax.experimental import mesh_utils +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap + +# Global flag to track if distributed has been initialized +_distributed_initialized = False + + +def _is_distributed_initialized(): + """Check if JAX distributed has been initialized.""" + return _distributed_initialized + + +def _initialize_distributed(args): + """Initialize JAX distributed with custom arguments.""" + global _distributed_initialized + + # Check if already initialized + if _distributed_initialized: + return + + if args.coordinator_address is None or args.num_processes is None or args.process_id is None: + raise ValueError( + "All distributed initialization arguments are required: " + "--coordinator-address, --num-processes, --process-id" + ) + if args.local_device_ids is None: + assert ( + args.num_devices_per_process is not None + ), "Either local_device_ids or num_devices_per_process must be provided" + # Calculate device range for this process + # Single process single device: each process gets one unique device + # Single process multiple devices: each process gets a unique range of devices + start_device = args.process_id * args.num_devices_per_process + device_range = range(start_device, start_device + args.num_devices_per_process) + global_device_ids_for_this_process = ",".join(map(str, device_range)) + else: + # Use 
explicitly provided global device IDs + global_device_ids_for_this_process = args.local_device_ids + args.num_devices_per_process = len(args.local_device_ids.split(",")) + + assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" + + print( + f"Initializing JAX distributed with coordinator={args.coordinator_address}, " + f"num_processes={args.num_processes}, process_id={args.process_id}" + ) + # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process" + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=global_device_ids_for_this_process, + ) + + _distributed_initialized = True + jax.clear_caches() + jax.config.update( + "jax_use_shardy_partitioner", False + ) # CollectiveGEMM does not work with Shardy yet + + assert jax.local_device_count() == 1, ( + f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" + f" {jax.local_device_count()}" + ) + + devices_per_process = 1 + num_total_devices = args.num_processes + + print( + f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," + f" devices_per_process={devices_per_process}, process_id={args.process_id}" + ) + + collective_gemm_bootstrap( + num_total_devices=num_total_devices, + num_devices_per_process=devices_per_process, + process_id=args.process_id, + tensor_parallel_size=args.tensor_parallel_size, + ) + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), 
"Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +def _create_mesh(args): + """Create mesh configuration with proper validation.""" + num_gpu = args.num_processes * args.num_devices_per_process + assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices" + num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args) + + print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)") + + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) + mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS)) + return mesh + + +def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): + """Create common argument parser for all collective GEMM tests.""" + parser = argparse.ArgumentParser(description=description) + + # Distributed initialization arguments + parser.add_argument( + "--coordinator-address", + type=str, + default=None, + help="Coordinator address for distributed initialization", + ) + parser.add_argument( + "--num-processes", + type=int, + default=None, + help="Number of processes for distributed initialization", + ) + parser.add_argument( + "--process-id", type=int, default=None, help="Process ID for distributed initialization" + ) + parser.add_argument( + "--local-device-ids", + type=str, + default=None, + help="Local device IDs for distributed initialization (comma-separated)", + ) + parser.add_argument( + "--num-devices-per-process", type=int, default=1, help="Number of devices per process" + ) + + # Test configuration arguments + parser.add_argument( + "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" + ) + parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + 
parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") + parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") + parser.add_argument( + "--collective-type", + type=str, + default="all_gather", + choices=["all_gather", "reduce_scatter"], + help="Type of collective operation", + ) + parser.add_argument( + "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + ) + parser.add_argument( + "--enable-data-parallel", action="store_true", help="Enable data parallelism" + ) + parser.add_argument( + "--enable-result-check", action="store_true", default=True, help="Enable result checking" + ) + + return parser diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py new file mode 100644 index 0000000000..83937971a4 --- /dev/null +++ b/examples/jax/collective_gemm/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""config for collective_gemm tests""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for collective_gemm tests""" + parser.addoption("--coordinator-address", action="store", default="localhost:12345") + parser.addoption("--num-processes", action="store", default=1) + parser.addoption("--process-id", action="store", default=0) + parser.addoption("--local-device-ids", action="store", default=None) + + +@pytest.fixture(autouse=True) +def distributed_args(request): + """Fixture for querying distributed initialization arguments""" + if request.cls: + request.cls.coordinator_address = request.config.getoption("--coordinator-address") + request.cls.num_processes = int(request.config.getoption("--num-processes")) + request.cls.process_id = int(request.config.getoption("--process-id")) + request.cls.local_device_ids = request.config.getoption("--local-device-ids") + request.cls.num_devices_per_process = ( + 1 + if request.cls.local_device_ids is None + else len(request.cls.local_device_ids.split(",")) + ) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh new file mode 100644 index 0000000000..5bf7ccb59a --- /dev/null +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +# Check if NVLINK is supported before running tests +echo "*** Checking NVLINK support***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? 
+ +# Check if command failed OR output indicates no NVLINK +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform" + echo "Collective GEMM tests require NVLINK connectivity" + echo "SKIPPING all tests" + exit 0 +else + echo "NVLINK support detected" +fi + +# Define the test files to run +TEST_FILES=( +"test_gemm.py" +"test_dense_grad.py" +"test_layernorm_mlp_grad.py" +) + +echo +echo "*** Executing tests in examples/jax/collective_gemm/ ***" + +HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM + +# Run each test file across all GPUs +for TEST_FILE in "${TEST_FILES[@]}"; do + echo + echo "=== Starting test file: $TEST_FILE ..." + + # Clear PIDs array for this test file + PIDS=() + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_FILE}_gpu_${i}.log" + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! 
+ PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done + + # Wait for all processes to finish + wait + + # Check and print the log content from process 0 (now has log file thanks to tee) + if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE SKIPPED" + elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE FAILED" + HAS_FAILURE=1 + else + echo "... $TEST_FILE PASSED" + fi + + # Remove the log files after processing them + wait + rm ${TEST_FILE}_gpu_*.log +done + +wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + +exit $HAS_FAILURE diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py new file mode 100644 index 0000000000..df2dd5618d --- /dev/null +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.dense import dense + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOp, + CollectiveOpSet, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(collective_op): + if collective_op.is_all_gather: + input_axes = (DP_AXIS, TPSP_AXIS, None) + weight_axes = (None, TPSP_AXIS) + bias_axes = (TPSP_AXIS,) + output_axes = (DP_AXIS, None, TPSP_AXIS) + else: # RS + input_axes = (DP_AXIS, None, TPSP_AXIS) + weight_axes = (TPSP_AXIS, None) + bias_axes = (None,) + output_axes = (DP_AXIS, TPSP_AXIS, None) + return input_axes, weight_axes, bias_axes, output_axes + + +def _get_operand_sharding(mesh, collective_op): + input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes)) + weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes)) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes)) + return x_sharding, weight_sharding, bias_sharding + + +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + output = dense( + x, + weight, + bias, + contracting_dims=((2,), (0,)), + input_axes=input_axes, + kernel_axes=weight_axes, + output_axes=output_axes, + collective_op_set=collective_op_set, + ) + return jnp.mean(output.astype(jnp.float32)) + + +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + return 
jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + ) + + +def run_dense_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. 
This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op) + ref_output, ref_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + noop_collective_op_set, + ) + output, sharded_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + collective_op_set, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveDenseGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + 
self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather(self): + """Test Collective Dense Gradient with AllGather""" + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter(self): + """Test Collective Dense Gradient with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_dense_grad.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py new file mode 100644 index 0000000000..307e4444e7 --- /dev/null +++ b/examples/jax/collective_gemm/test_gemm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective GEMM test on multi-GPU with tensor parallelism + +This script uses custom distributed initialization with the following arguments: +- --coordinator-address: Coordinator address for distributed initialization +- --num-processes: Number of processes for distributed initialization +- --process-id: Process ID for distributed initialization +- --local-device-ids: Local device IDs for distributed initialization + +Example: + python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3 +""" +import unittest +import os +from functools import partial + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp +from 
transformer_engine.jax.sharding import MeshResource + + +def _get_operand_sharding(mesh, collective_op, is_with_dp): + + dp_axis = DP_AXIS if is_with_dp else None + if collective_op == CollectiveOp.ALL_GATHER: + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS)) + bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + else: # RS + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(None)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) +def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + collective_op=collective_op, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +def run_gemm_tests(args, mesh=None): + """Execute GEMM 
tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + print(f"Device mesh: {mesh}") + + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( + mesh, collective_op, args.enable_data_parallel + ) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + ref_output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=CollectiveOp.NONE, + output_sharding=output_sharding, + ) + output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=collective_op, + # CollectiveGEMM output should have a correct sharding without applying sharding constraint + output_sharding=None, + ) + assert ( + ref_output.sharding == output.sharding + ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}" + gathered_ref_output = jax.lax.with_sharding_constraint( + ref_output, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_output = 
jax.lax.with_sharding_constraint( + output, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered_ref_output) + jax.block_until_ready(gathered_output) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(gathered_ref_output, gathered_output) + + +class TestCollectiveGemmWithDP(unittest.TestCase): + """Collective GEMM with DP unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective GEMM test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather_with_dp(self): + """Test Collective GEMM with AllGather""" + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter_with_dp(self): + """Test Collective GEMM with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 5: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_gemm.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + sys.exit(1) + + args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args() + _initialize_distributed(args) + run_gemm_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py new file mode 100644 index 0000000000..7bd6eb6a30 --- /dev/null +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.layernorm_mlp import layernorm_mlp + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOpSet, + CollectiveOp, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(): + input_1_axes = (DP_AXIS, TPSP_AXIS, None) + weight_1_axes = (None, None, TPSP_AXIS) + bias_axes_1 = (None, TPSP_AXIS) + input_2_axes = (DP_AXIS, None, TPSP_AXIS) + weight_2_axes = (TPSP_AXIS, None) + bias_axes_2 = (None,) + return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 + + +def _get_operand_sharding(mesh): + input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = ( + _get_logical_axes() + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes)) + weight_1_sharding = 
NamedSharding(mesh, PartitionSpec(*weight_1_axes)) + bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1)) + weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes)) + bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2)) + return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding + + +def _mean_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + output = layernorm_mlp( + x, + gamma, + beta=None, + kernels=[weight_1, weight_2], + biases=[bias_1, bias_2], + norm_type="rmsnorm", + dot_1_input_axes=input_1_axes, + dot_2_input_axes=input_2_axes, + kernel_1_axes=weight_1_axes, + kernel_2_axes=weight_2_axes, + activation_type=("gelu",), + collective_op_sets=collective_op_sets, + ) + return jnp.mean(output) + + +def _value_and_grad_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + return jax.jit( + jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) + )( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + + +def run_layernorm_mlp_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split( + rng, 7 + ) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight_1 = 
jax.random.normal( + weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_in) + bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) + weight_2 = jax.random.normal( + weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_out) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) + gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt( + args.hidden_in + ) + collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER) + collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER) + collective_op_sets = (collective_op_set_1, collective_op_set_2) + noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. 
This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = ( + _get_operand_sharding(mesh) + ) + x_sharded = jax.device_put(x, x_sharding) + weight_1_sharded = jax.device_put(weight_1, weight_1_sharding) + bias_1_sharded = jax.device_put(bias_1, bias_1_sharding) + weight_2_sharded = jax.device_put(weight_2, weight_2_sharding) + bias_2_sharded = jax.device_put(bias_2, bias_2_sharding) + + input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes() + ref_output, ref_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + noop_collective_op_sets, + ) + output, sharded_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, 
gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveLayerNormMLPGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_layernorm_mlp_grad(self): + """Test Collective Dense Gradient with AllGather""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_layernorm_mlp_grad.py --coordinator-address
" + " --num-processes --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index d9c46347fd..ae45f398e8 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -29,6 +29,10 @@ wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +wait + +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" +wait if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e120..56369db27f 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, 
atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! 
[UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { } } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? 
source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor @@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 
0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 1ce89c512f..6c7bed55ac 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + 
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; + printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 4d65e26ce8..cffc411a0d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -67,6 +67,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -78,17 +83,26 @@ class CommOverlapCore { virtual 
~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 941899b28c..c2ce684c4e 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h 
@@ -12,6 +12,8 @@ #include #include +#include "nccl.h" + #ifdef NVTE_WITH_CUBLASMP #include #endif // NVTE_WITH_CUBLASMP @@ -104,4 +106,12 @@ #endif // NVTE_WITH_CUBLASMP +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 118000be7a..e5fcdac3c8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,8 +6,10 @@ import math import operator from collections.abc import Iterable -from typing import Tuple, Sequence, Union +from dataclasses import dataclass from functools import partial, reduce +from typing import Tuple, Sequence, Union +from enum import Enum import warnings import jax @@ -16,8 +18,13 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -import transformer_engine_jax as tex -from transformer_engine_jax import get_num_compute_streams +from transformer_engine_jax import ( + get_num_compute_streams, + JAXX_Collective_Op, + get_device_compute_capability, + initialize_cgemm_communicator, + get_cgemm_num_max_streams, +) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -37,11 +44,19 @@ is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) -from ..sharding import global_mesh_resource -from .misc import get_padded_spec +from .misc import get_padded_spec, is_all_reduce_in_float32 +from ..sharding import ( + global_mesh_resource, + tpsp_axis_size, + dp_or_fsdp_axis_size, +) __all__ = [ + "CollectiveOp", + "CollectiveOpSet", + "collective_gemm_bootstrap", + "noop_collective_op_set", "gemm", "grouped_gemm", 
"gemm_uses_jax_dot", @@ -56,7 +71,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) >= 90: + if get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 @@ -152,6 +167,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +def collective_gemm_bootstrap( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams=3, + compute_stream_priority=0, + communication_stream_priority=0, + num_sm_for_communication=2, + use_ce=True, + aggregate_all_gather=False, +): + """Initialize NCCL communicators for Collective GEMM operations. + + This function sets up the distributed communication infrastructure needed for + tensor parallel collective GEMM operations. It supports two main scenarios: + + 1. **Multi-device per process**: TP domain = single process + - Each process manages multiple GPUs (num_devices_per_process > 1) + - TP group consists of GPUs within the same process + - Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4 + + 2. **Single device per process**: TP domain spans multiple processes + - Each process manages one GPU (num_devices_per_process = 1) + - TP group spans across multiple processes + - Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4 + + Args: + num_total_devices (int): Total number of ranks across all processes. + Must be divisible by num_devices_per_process. + num_devices_per_process (int): Number of GPUs per process. + - For multi-device: equals tp_size (e.g., 4 GPUs per process) + - For single-device: equals 1 (1 GPU per process) + process_id (int): Process identifier (0-based). + Must be in range [0, num_total_devices // num_devices_per_process). + tensor_parallel_size (int): Size of tensor parallel groups. + Must divide num_total_devices evenly. 
+ num_max_streams (int, optional): Maximum number of CUDA streams for overlap. + Higher values enable more parallelism but use more GPU resources. Default: 3. + compute_stream_priority (int, optional): Priority for GEMM computation streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + communication_stream_priority (int, optional): Priority for NCCL communication streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + num_sm_for_communication (int, optional): Number of streaming multiprocessors + reserved for communication operations. Default: 2. + use_ce (bool, optional): Enable CUDA copy engines for memory transfers. + Can improve performance by offloading memory operations. Default: True. + aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations + into larger ones for better efficiency. Default: False. + + Raises: + AssertionError: If num_total_devices is not divisible by num_devices_per_process, + or if process_id is out of valid range. + AssertionError: If num_devices_per_process is not 1 (Temporary: only single device per process is supported for now) + RuntimeError: If NCCL initialization fails or if configuration + is invalid (e.g., insufficient GPUs). 
+ + Example: + # Basic initialization (single device per process) + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4 + ) + + # Advanced configuration with custom performance settings + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4, + num_max_streams=5, # More parallelism + compute_stream_priority=1, # Lower compute priority + communication_stream_priority=0, # Higher comm priority + num_sm_for_communication=4, # More SMs for communication + use_ce=True, # Enable copy engines + aggregate_all_gather=True # Aggregate small operations + ) + + Note: + This function must be called after JAX distributed initialization + and before any collective GEMM operations. Each process should call + this function with its own unique process_id. + """ + + assert ( + num_devices_per_process == 1 and jax.local_device_count() == 1 + ), "Only single device per process is supported at the moment!" 
+ assert num_total_devices % num_devices_per_process == 0, ( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + initialize_cgemm_communicator( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams, + compute_stream_priority, + communication_stream_priority, + num_sm_for_communication, + use_ce, + aggregate_all_gather, + ) + + +class CollectiveOp(Enum): + "Enum for Collective Type in Collective GEMM" + + NONE = JAXX_Collective_Op.NONE + ALL_GATHER = JAXX_Collective_Op.ALL_GATHER + REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER + + @property + def is_all_gather(self) -> bool: + """Check if AllGather""" + return self == CollectiveOp.ALL_GATHER + + @property + def is_reduce_scatter(self) -> bool: + """Check if ReduceScatter""" + return self == CollectiveOp.REDUCE_SCATTER + + @property + def is_none(self) -> bool: + """Check if None""" + return self == CollectiveOp.NONE + + +@dataclass(frozen=True) +class CollectiveOpSet: + """ + A set of CollectiveOp objects that provide complementary collective GEMM configurations for the Forward and Backward passes through Dense-layers. 
+ """ + + forward: CollectiveOp + backward: CollectiveOp + + @staticmethod + def create(forward_collective_op: CollectiveOp): + """Create a set of CollectiveOp for forward and backward passes""" + if forward_collective_op.is_all_gather: + backward_collective_op = CollectiveOp.REDUCE_SCATTER + elif forward_collective_op.is_reduce_scatter: + backward_collective_op = CollectiveOp.ALL_GATHER + else: + backward_collective_op = CollectiveOp.NONE + return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op) + + +noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE) + + @partial(jax.jit, static_argnums=(1, 2)) def swizzled_scale(scale_inv, flatten_axis, is_colwise): "Swizzle scale_inv via JAX transpose ops" @@ -174,7 +344,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 inner_primitive = None outer_primitive = None @@ -193,8 +363,12 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del use_split_accumulator + del use_split_accumulator, transpose_batch_sequence def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -238,7 +412,7 @@ def _dims_are_consecutive(dims): ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." 
if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING - and not tex.is_non_nt_fp8_gemm_supported() + and not is_fp8_gemm_with_all_layouts_supported() ): assert not lhs_is_transposed and rhs_is_transposed, ( "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " @@ -263,6 +437,19 @@ def _dims_are_consecutive(dims): out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + # Adjust output shape for comm+GEMM overlap + if not collective_op.is_none and not is_outer: # Inner abstract + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + overlap_out_shape = list(out_shape).copy() + if collective_op.is_all_gather: + overlap_out_shape[1] *= tpsp_axis_size() + else: # RS + overlap_out_shape[sequence_dim] = ( + overlap_out_shape[sequence_dim] // tpsp_axis_size() + ) + assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) + # Validate bias bias_shape = (0,) bias_dtype = out_dtype @@ -302,9 +489,12 @@ def _dims_are_consecutive(dims): pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) # Declare cuBLAS workspace + workspace_size = get_cublas_workspace_size_bytes() + if not collective_op.is_none: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. 
- workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, workspace @@ -330,8 +520,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del out_dtype + del out_dtype, transpose_batch_sequence, sequence_dim, is_outer lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -350,6 +544,7 @@ def lowering( "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, + "collective_op": int(collective_op.value), } operand_output_aliases = {} @@ -378,6 +573,10 @@ def impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -396,7 +595,34 @@ def impl( lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) - outputs = GemmPrimitive.inner_primitive.bind( + # Alter lhs blocks so that CGEMM RS outputs correctly + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + original_shape = lhs.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs = reordered.reshape(original_shape) + + (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, @@ -410,8 +636,39 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ) - return outputs[:-1] # discard workspace array + # Alter output blocks for CGEMM AG + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not output.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + original_shape = output.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = output.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + output = reordered.reshape(original_shape) + + return [output, bias_grad, pre_gelu_out] @staticmethod def outer_impl( @@ -428,6 +685,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): return GemmPrimitive.impl( lhs, @@ -443,6 +704,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ) @staticmethod @@ -456,7 +721,12 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + collective_op, + transpose_batch_sequence, + sequence_dim, + is_outer, ): + del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims @@ -484,6 +754,10 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -492,6 +766,8 @@ def batcher( def _parse_operand_output_specs( arg_infos, contracting_dims, + transpose_batch_sequence, + collective_op, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ 
-499,14 +775,12 @@ def _parse_operand_output_specs( # Ensure that tensor sequence parallelism is not used via setting tp_resource if gsr.tp_resource is not None: - for i in range(len(lhs_specs) - 1): - if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: - warnings.warn( - "Tensor sequence parallelism is detected as" - f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" - f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" - " tensor sequence parallelism to avoid potential issues." - ) + if gsr.tp_resource in lhs_specs: + warnings.warn( + "Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'" + " appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource" + " for tensor sequence parallelism to avoid potential issues." + ) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) @@ -528,10 +802,43 @@ def _parse_operand_output_specs( assert reduce_spec is None, "Multiple reduce dimension is detected!" reduce_spec = l + sequence_dim = None + + # Find sequence dimension in lhs_specs if tensor sequence parallel is enabled + # We only do CollectiveGemm AG on the x or dY thus they always the LHS and have sequence dim + if collective_op.is_all_gather: + try: + tpsp_idx = lhs_specs.index(gsr.tpsp_resource) + except ValueError as exc: + raise ValueError( + f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}." + " Please check your sharding configuration." + ) from exc + sequence_dim = tpsp_idx + assert (sequence_dim == 1) ^ transpose_batch_sequence, ( + "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" + " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" + f" sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." 
+ ) + + elif collective_op.is_reduce_scatter: + assert reduce_spec == gsr.tpsp_resource, ( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) + sequence_dim = int(not transpose_batch_sequence) + if reduce_spec is not None: # Other non-reduce cdims (if exists) need to be unsharded lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) - rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + # Only do AG Sequence dim if not Overlap + if collective_op.is_all_gather: + rhs_cspecs = tuple( + s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs + ) + else: + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. @@ -551,13 +858,31 @@ def _parse_operand_output_specs( for spec in rhs_non_cspecs ) - # Non-contracting dims of LHS to be gathered along the SP axis. - # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for - # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. - lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + # Only do AG Sequence dim if not Overlap + if not collective_op.is_all_gather: + # Non-contracting dims of LHS to be gathered along the SP axis. + # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. 
+ lhs_non_cspecs = tuple( + None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs + ) out_specs = lhs_non_cspecs + rhs_non_cspecs + # Only do AG Sequence dim if not Overlap RS + if collective_op.is_all_gather: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] + elif collective_op.is_reduce_scatter: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = ( + out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] + ) + # specs = merge(cspecs, non_cspecs) lhs_specs, rhs_specs = map( lambda cdims, cspecs, non_cspecs: ( @@ -572,10 +897,14 @@ def _parse_operand_output_specs( bias_specs = tuple(list(rhs_non_cspecs).copy()) gelu_specs = tuple(list(out_specs).copy()) + if not collective_op.is_none: + assert sequence_dim >= 0, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, + sequence_dim, ) @staticmethod @@ -587,6 +916,10 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, @@ -595,11 +928,16 @@ def infer_sharding_from_operands( out_dtype, scaling_mode, grad, + use_split_accumulator, + result_infos, + is_outer, + sequence_dim, ) - del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs( + arg_infos, contracting_dims, transpose_batch_sequence, collective_op + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -624,20 +962,29 @@ def partition( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, ): - del result_infos + del result_infos, is_outer, sequence_dim ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + inferred_sequence_dim, + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + ) - # Assemble argument shardings - # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + # Block scale inverses match their operands, but tensor scale inverses are unsharded. 
none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) @@ -686,11 +1033,19 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=inferred_sequence_dim, + is_outer=False, + collective_op=collective_op, ) - # All-Reduce GEMM output - if reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if reduce_spec is not None and not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # For unittest only + outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( + out_dtype + ) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) return outputs @@ -705,12 +1060,22 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, operand_types, result_types, ): del out_dtype, grad, use_split_accumulator - del mesh, result_types + del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer + + if not collective_op.is_none: + raise NotImplementedError( + "CollectiveGEMM with Shardy propagation is not supported yet! 
Please turn off" + " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + ) prefix = "Gemm_" @@ -792,6 +1157,8 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -800,6 +1167,7 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -859,6 +1227,10 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=-1, + is_outer=True, + collective_op=collective_op, ) @@ -1176,6 +1548,8 @@ def gemm( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, **kwargs, ) -> Tuple[jnp.ndarray, ...]: r"""General matrix multiplication with optional quantization. @@ -1209,8 +1583,11 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only - supported with TE's custom call to cuBLAS GEMM. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + transpose_batch_sequence: bool, default = False + Transpose the batch and sequence dimensions of the input tensor. 
+ collective_op: CollectiveOp, default = CollectiveOp.NONE + Collective operation type for collective GEMM. Returns ------- @@ -1254,6 +1631,7 @@ def gemm( "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) + assert collective_op.is_none, "JAX GEMM does not support collective GEMM" return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( @@ -1262,6 +1640,8 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op, **kwargs, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 3bda37128b..52f5edbf3a 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -293,3 +293,11 @@ def duplicate_with_new_description(self, desc: str): Create a new NamedSharding with the same mesh and spec but with a new description. 
""" return NamedSharding(self.mesh, self.spec, desc=desc) + + +@functools.lru_cache(maxsize=1) +def is_all_reduce_in_float32(): + """ + Check if all-reduce is in float32 + """ + return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1" diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f0..92937dd461 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -32,9 +33,6 @@ #include "transformer_engine/activation.h" #include "transformer_engine/multi_stream.h" -// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace -XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); - namespace transformer_engine { namespace jax { @@ -121,6 +119,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); @@ -134,4 +133,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); + #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp new file mode 100644 index 0000000000..7082bfb035 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -0,0 +1,259 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +#include "cgemm_helper.h" + +#include "common/util/system.h" +#include "nccl.h" + +namespace transformer_engine { +namespace jax { + +ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) { + ncclUniqueId unique_id; + + int tp_domain_id = get_tp_domain_id(); + bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0); + + pid_t pgid = getpgid(0); + + std::string base_path = getenv("NVTE_JAX_NCCL_FILE_PATH", "/tmp"); + std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) + + "_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + + "_domain_" + std::to_string(tp_domain_id) + ".bin"; + + if (is_tp_leader) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id)); + + // Write the ID to a temporary file + std::ofstream file(id_file, std::ios::binary); + NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file); + file.write(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + file.close(); + } else { + // Wait for the ID file to be created and read it + int attempts = 0; + const int max_attempts = 100; + while (attempts < max_attempts) { + std::ifstream file(id_file, std::ios::binary); + if (file.is_open()) { + file.read(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + if (file.gcount() == sizeof(ncclUniqueId)) { + file.close(); + break; + } + file.close(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + attempts++; + } + NVTE_CHECK(attempts < max_attempts, + "Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file); + } + + if (is_tp_leader) { + _nccl_id_file_name.push_back(id_file); + } + + return unique_id; +} + +void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size) { + // Validate inputs + 
NVTE_CHECK(num_devices_per_process == 1, + "num_devices_per_process must be == 1, got num_devices_per_process=", + num_devices_per_process); + NVTE_CHECK(num_total_devices >= 1, + "num_total_devices must be >= 1, got num_total_devices=", num_total_devices); + NVTE_CHECK( + num_total_devices % num_devices_per_process == 0, + "num_total_devices must be divisible by num_devices_per_process, got num_total_devices=", + num_total_devices, ", num_devices_per_process=", num_devices_per_process); + + // Validate TP size + NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size); + NVTE_CHECK(num_total_devices % tp_size == 0, + "num_total_devices must be divisible by tp_size, got num_total_devices=", + num_total_devices, ", tp_size=", tp_size); + + auto &handler = get(false); + handler.num_total_devices = num_total_devices; + handler.num_devices_per_process = num_devices_per_process; + handler.process_id = process_id; + handler.num_processes = num_total_devices / num_devices_per_process; + handler.tp_size = tp_size; + handler.tp_num_domains = num_total_devices / tp_size; + + // Initialize vectors with the correct size + handler.local_device_ids_within_process.resize(num_devices_per_process); + handler.local_device_ids_within_tp_domain.resize(num_devices_per_process); + handler.tp_domain_ids.resize(num_devices_per_process); + handler.global_device_ids.resize(num_devices_per_process); + handler.tp_comms.resize(num_devices_per_process); + + NVTE_CHECK(0 <= process_id && process_id < handler.num_processes, + "Invalid process_id=", process_id, ", which is out of range [0, ", + handler.num_processes, ")"); + + // Initialize local devices and calculate their global device IDs and TP topology + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + // Use the device that JAX has already assigned to this process + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + handler.local_device_ids_within_process[local_idx] = current_device; 
+ handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx; + + // Calculate TP-related values for this device + int global_device_id = handler.global_device_ids[local_idx]; + if (num_devices_per_process == tp_size) { + // Scenario 1: Multi-device per process - TP domain = single process + handler.local_device_ids_within_tp_domain[local_idx] = local_idx; + handler.tp_domain_ids[local_idx] = process_id; + } else { + // Scenario 2: Single device per process - TP domain spans multiple processes + handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size; + handler.tp_domain_ids[local_idx] = global_device_id / tp_size; + } + } + + ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp"); + + NVTE_CHECK_NCCL(ncclGroupStart()); + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx])); + int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx]; + NVTE_CHECK_NCCL( + ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank)); + } + NVTE_CHECK_NCCL(ncclGroupEnd()); + + // Allocate device memory for barrier operations + NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); + + handler._initialize = true; + + // Bootstrap UB via creating a dummy CommOverlapP2PBase object + std::vector buffer_shape{1, 1}; + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, + JAXX_Collective_Op::ALL_GATHER); +} + +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag) { + auto &config = CgemmConfig::get(false); + config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); + auto &handler = CommunicatorHandler::get(false); + 
handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); +} + +int GetCgemmNumMaxStreams() { + auto &config = CgemmConfig::get(); + return config.num_max_streams; +} + +CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, + DType dtype, + JAXX_Collective_Op collective_op) { + auto &comm_handler = CommunicatorHandler::get(); + auto &cgemm_config = CgemmConfig::get(); + + int device_idx = comm_handler.get_local_device_idx_for_current_device(); + int64_t plan_id = 0; + hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast(dtype), + static_cast(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx); + + auto it = plan_map.find(plan_id); + if (it != plan_map.end()) { + return it->second.get(); + } + + if (comm_handler.num_devices_per_process == comm_handler.tp_size) { + // Multi-device per process + } else if (comm_handler.num_devices_per_process == 1) { + // Single device per process + NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0, + "For single device per process, num_total_devices must be divisible by tp_size, " + "got num_total_devices=", + comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size); + } else { + NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=", + comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size, + ". 
Supported scenarios: " + "(1) num_devices_per_process == tp_size (multi-device per process), " + "(2) num_devices_per_process == 1 (single device per process)"); + } + + std::unique_ptr executor; + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), + cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, + cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, + cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + + CommOverlapCore *executor_ptr = executor.get(); + plan_map[plan_id] = std::move(executor); + return executor_ptr; +} + +void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + NVTE_CHECK_NCCL( + ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes, + void *input_buf, size_t input_bytes, ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + size_t expected_output_bytes = input_bytes * tp_size; + NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ", + expected_output_bytes, ", got ", output_bytes); + + NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr)); + 
cudaDeviceSynchronize(); +} + +CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) { + allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm comm) { + this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm); + }; + barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); }; +} + +CommunicatorHandler::~CommunicatorHandler() { + if (_initialize && !tp_comms.empty()) { + for (auto &comm : tp_comms) { + if (comm != nullptr) { + ncclCommDestroy(comm); + } + } + } + if (_device_barrier) cudaFree(_device_barrier); + + for (const auto &file_path : _nccl_id_file_name) { + std::remove(file_path.c_str()); + } +} + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h new file mode 100644 index 0000000000..84b2b81540 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -0,0 +1,189 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "transformer_engine/comm_gemm_overlap.h" + +namespace transformer_engine { +namespace jax { + +// Configuration singleton for CGEMM parameters +class CgemmConfig { + public: + int num_max_streams; + int gemm_priority; + int comm_priority; + int num_comm_sm; + bool use_ce; + bool aggregate_ag; + + static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm, + bool _use_ce, bool _aggregate_ag) { + auto &config = get(false); + config._initialized = true; + config.num_max_streams = _num_max_streams; + config.gemm_priority = _gemm_priority; + config.comm_priority = _comm_priority; + config.num_comm_sm = _num_comm_sm; + config.use_ce = _use_ce; + config.aggregate_ag = _aggregate_ag; + } + + static CgemmConfig &get(bool is_initialized = true) { + static thread_local CgemmConfig instance; + NVTE_CHECK( + instance._initialized == is_initialized, + "CgemmConfig must be initialized before using it, got is_initialized=", is_initialized); + return instance; + } + + CgemmConfig(const CgemmConfig &) = delete; + CgemmConfig &operator=(const CgemmConfig &) = delete; + + private: + CgemmConfig() = default; + ~CgemmConfig() = default; + bool _initialized = false; +}; + +// Forward declaration +class CollectiveGemmPlanRegistry; + +// NCCL communicator handler for collective GEMM operations +// Support both single process single device AND single process multi device +// Two scenarios: +// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size) +// 2. 
Single process single device: TP domain spans processes (num_devices_per_process == 1) +class CommunicatorHandler { + public: + int num_total_devices = -1; + int num_devices_per_process = -1; + int process_id = -1; + int num_processes = -1; + + int tp_size = -1; + int tp_num_domains = -1; + std::vector local_device_ids_within_tp_domain; + std::vector tp_domain_ids; + std::vector tp_comms; + + std::vector local_device_ids_within_process; + std::vector global_device_ids; + + int get_global_rank() const { + int device_idx = get_local_device_idx_for_current_device(); + return global_device_ids[device_idx]; + } + + void nccl_device_barrier_impl(ExtComm); + void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm); + + ncclComm_t get_comm_for_current_device() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_comms[device_idx]; + } + + int get_local_device_idx_for_current_device() const { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + for (int i = 0; i < num_devices_per_process; i++) { + if (local_device_ids_within_process[i] == current_device) { + return i; + } + } + NVTE_ERROR("Current CUDA device ", current_device, + " not found in local_device_ids_within_process"); + } + + int get_local_device_id_within_tp_domain() const { + int device_idx = get_local_device_idx_for_current_device(); + return local_device_ids_within_tp_domain[device_idx]; + } + + int get_tp_domain_id() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_domain_ids[device_idx]; + } + + int get_tp_num_domains() const { return tp_num_domains; } + + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + + private: + ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); + + public: + static CommunicatorHandler &get(bool is_initialized = true) { + static CommunicatorHandler instance; + 
NVTE_CHECK(instance._initialize == is_initialized, + "CommunicatorHandler._initialize=", instance._initialize, + ", is_initialized=", is_initialized); + return instance; + } + + ExtAllgatherOp allgather_func; + ExtBarrierOp barrier_func; + + CommunicatorHandler(const CommunicatorHandler &) = delete; + CommunicatorHandler &operator=(const CommunicatorHandler &) = delete; + + private: + CommunicatorHandler(); + ~CommunicatorHandler(); + + bool _initialize = false; + int *_device_barrier = nullptr; + std::vector _nccl_id_file_name; +}; + +// Plan registry for caching collective GEMM executors +class CollectiveGemmPlanRegistry { + public: + static CollectiveGemmPlanRegistry &getInstance() { + static thread_local CollectiveGemmPlanRegistry instance; + return instance; + } + + CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, + JAXX_Collective_Op collective_op); + + private: + CollectiveGemmPlanRegistry() {} + CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete; + CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete; + + std::unordered_map> plan_map; +}; + +// Function declarations +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag); + +int GetCgemmNumMaxStreams(); + +} // namespace jax +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 06dded1d86..1467fa8873 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -6,13 +6,19 @@ #include "transformer_engine/gemm.h" #include +#include +#include #include #include #include "../extensions.h" +#include "cgemm_helper.h" +#include "common.h" #include 
"common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" +#include "cuda_runtime.h" +#include "nccl.h" #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -66,12 +72,75 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + nvte_cublas_handle_init(); + + // Init UB buffer + if (collective_op != JAXX_Collective_Op::NONE) { + auto &comm_handler = CommunicatorHandler::get(); + std::vector lhs_shape = { + product(lhs.dimensions(), 0, lhs_axis_boundary), + product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())}; + std::vector rhs_shape = { + product(rhs.dimensions(), 0, rhs_axis_boundary), + product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())}; + + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? 
rhs_shape[0] : rhs_shape[1]}; + + std::vector buffer_shape{0, 0}; + DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + } + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, + collective_op); + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, + FFI::Bind() + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator") + .Attr("collective_op")); + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, + JAXX_Collective_Op collective_op) { // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = 
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || @@ -83,16 +152,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); - // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, " - "expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; @@ -133,9 +195,62 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + + if (collective_op == JAXX_Collective_Op::NONE) { + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); + + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + 
} else { + std::vector buffer_shape{0, 0}; + DType buffer_dtype = out_dtype; + auto &comm_handler = CommunicatorHandler::get(); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + out_shape[0] = out_shape[0] * comm_handler.tp_size; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + out_shape[0] = out_shape[0] / comm_handler.tp_size; + } + auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, buffer_dtype, collective_op); + if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_, + pre_gelu_, workspace_, grad, false, use_split_accumulator, out_, + stream); + + } else if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty + + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + // Copy the distributed LHS operand into the 
local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, stream); + } + } return ffi_with_cuda_error_check(); } @@ -161,7 +276,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator"), + .Attr("use_split_accumulator") + .Attr("collective_op"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb6..c8fb713d7d 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -87,5 +87,31 @@ constexpr struct Alignment { std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +template +void hash_combine(int64_t &seed, const T &v, Rest... 
rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + +enum class JAXX_Collective_Op : int64_t { + NONE = 0, + ALL_GATHER = 1, + REDUCE_SCATTER = 2, +}; + +static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { + switch (op) { + case JAXX_Collective_Op::ALL_GATHER: + return CommOverlapType::AG; + break; + case JAXX_Collective_Op::REDUCE_SCATTER: + return CommOverlapType::RS; + break; + default: + NVTE_ERROR("Invalid Collective Op ", static_cast(op)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c1..06e2e2e005 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -5,6 +5,8 @@ ************************************************************************/ #include "../extensions.h" +#include "cgemm_helper.h" +#include "common/util/cuda_runtime.h" namespace transformer_engine { namespace jax { @@ -57,7 +59,7 @@ pybind11::dict Registrations() { // GEMM dict["te_gemm_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM @@ -84,6 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); + m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); + m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) @@ -159,6 +163,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) { 
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) .export_values(); + + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) + .value("NONE", JAXX_Collective_Op::NONE) + .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) + .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER) + .export_values(); } } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index dd7f5e0e84..23df1a0ce2 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -11,6 +11,7 @@ from typing import Tuple, Sequence from functools import partial +import warnings import jax import jax.numpy as jnp @@ -62,10 +63,13 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + batch_sequence_transpose: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - quantizer_set: QuantizerSet = noop_quantizer_set, + output_axes: Tuple[str, ...] = None, using_global_amax_of_x: bool = False, + collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, + quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -78,12 +82,20 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix + output_axes: Logical axes for sharding the output using_global_amax_of_x: Indicate wether to use global amax for x. 
Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. + quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ + if batch_sequence_transpose: + warnings.warn("batch_sequence_transpose is not well tested, use with caution!") + if not get_quantize_config().is_fp8_enabled(): input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -93,32 +105,30 @@ def dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ) return output -@partial( - jax.custom_vjp, - nondiff_argnums=( - 3, - 4, - 5, - 7, - ), -) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) def _dense( x, kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, # need to be a diff_arg for DelayedScaling state management ): """Internal implementation of dense layer transformation with custom VJP. @@ -130,10 +140,13 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input + output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix - quantizer_set: QuantizerSet which contains quantizers for different tensor types using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. 
+ quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor @@ -143,10 +156,13 @@ def _dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ) return output @@ -156,10 +172,13 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -202,9 +221,12 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set.forward, ) + output = with_sharding_constraint_by_logical_axes(output, output_axes) if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape @@ -223,8 +245,16 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad -): # pylint: disable=unused-argument + contracting_dims, + batch_sequence_transpose, + input_axes, + kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, + ctx, + grad, +): """Backward pass rule for dense layer transformation. 
Returns: @@ -239,6 +269,7 @@ def _dense_bwd_rule( quantizer_set, flatten_axis_k, ) = ctx + grad = with_sharding_constraint_by_logical_axes(grad, output_axes) fwd_x_contracting_dims, fwd_k_contracting_dims = map( tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims @@ -266,8 +297,9 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set.backward, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -279,7 +311,10 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9ae..ad66684f2b 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[ return drop_path_shape +# TODO(Phuong): move this function to sharding.py def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: """ Extend the given Flax logical axis rules with the predefined TransformerLayer's diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index e3eaa53e1d..cf77f8e0a0 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -41,6 +41,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, + batch_sequence_transpose: bool = False, 
norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, @@ -49,6 +50,10 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + collective_op_sets: Tuple[tex.CollectiveOpSet] = ( + tex.noop_collective_op_set, + tex.noop_collective_op_set, + ), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -72,6 +77,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -80,6 +86,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -122,6 +129,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -130,12 +138,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 
11, 12, 13, 14, 15, 16, 17, 18, 19)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -147,6 +156,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, + batch_sequence_transpose: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -155,6 +165,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,12 +185,16 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding + kernel_1_axes: Logical axes for first weight matrix sharding + kernel_2_axes: Logical axes for second weight matrix sharding ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of quantizer sets Returns: @@ -196,6 +211,7 @@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -204,6 +220,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output @@ -220,6 +237,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -228,6 +246,7 @@ 
def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -247,6 +266,10 @@ def _layernorm_mlp_fwd_rule( del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.forward.is_reduce_scatter + assert not collective_op_set_2.forward.is_all_gather # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) @@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_1.forward, ) if use_bias_1 and tex.gemm_uses_jax_dot(): @@ -326,8 +351,10 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_2.forward, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -335,6 +362,8 @@ def _layernorm_mlp_fwd_rule( bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) + # sharding of outputs should be the same as dot_1's input + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) ctx = ( @@ -364,6 +393,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, 
epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -372,6 +402,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, ctx, grad, ): @@ -410,6 +441,10 @@ def _layernorm_mlp_bwd_rule( ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.backward.is_all_gather + assert not collective_op_set_2.backward.is_reduce_scatter # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) @@ -436,6 +471,8 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_2.backward, ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -450,6 +487,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -476,6 +514,8 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_1.backward, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -486,6 +526,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git 
a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 339e74e2fc..7a82612695 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Callable, Optional import warnings + import jax import jax.numpy as jnp from jax.interpreters import pxla @@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x + + +def tpsp_axis_size(): + """ + Get the size of the tensor parallelism axis. + Return 1 if no TP axis is set. + """ + return get_mesh_axis_size(global_mesh_resource().tpsp_resource) + + +def dp_or_fsdp_axis_size(): + """ + Get the size of the data parallelism or FSDP axis. + Return 1 if no DP/FSDP axis is set. + """ + dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) + fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) + return dp_size if dp_size > 1 else fsdp_size From 963b39c50c4025eb447f524335ad239f57bc6935 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 28 Sep 2025 23:45:22 -0700 Subject: [PATCH 39/78] fix to enable --overlap-grad-reduce Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 94d83fd638..e4e622d157 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -294,11 +294,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): - if not ctx.cpu_offloading: - w = torch.nn.Parameter(weights[i], 
weights[i].requires_grad) - weights[i] = w - weights[i].main_grad = main_grads[i] - weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] + origin_weights[i].main_grad = main_grads[i] + origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) @@ -459,8 +456,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): list(weight.main_grad.shape), weight.dtype, ) - # TODO: Need to check why weight doesn't have attr grad_added_to_main_grad when fine_grained_activation_offloading is True. - elif ctx.fuse_wgrad_accumulation and not ctx.fine_grained_activation_offloading: + elif ctx.fuse_wgrad_accumulation: wgrad = None else: wgrad = None From a91e4585523f77a89cd41f12f3c869ee73572045 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Sep 2025 11:35:34 -0400 Subject: [PATCH 40/78] [JAX] Add xml export for `test_multiprocessing_encoder` and `test_cgemm` (#2210) * add xml export for test_multiprocessing_encoder and test_cgemm Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 12 +++- .../run_test_multiprocessing_encoder.sh | 61 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 5bf7ccb59a..af263eb53d 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -4,6 +4,10 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + # Check if NVLINK is supported before running tests echo "*** Checking NVLINK support***" NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) @@ -69,7 +73,8 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # For process 0: show live output AND save to log file using tee echo "=== Live 
output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -94,8 +99,11 @@ for TEST_FILE in "${TEST_FILES[@]}"; do elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then echo "... $TEST_FILE FAILED" HAS_FAILURE=1 - else + elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then echo "... $TEST_FILE PASSED" + else + echo "... $TEST_FILE INVALID" + HAS_FAILURE=1 fi # Remove the log files after processing them diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a1ac0f8fa..2a979e1775 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -15,11 +15,37 @@ TEST_CASES=( "test_te_current_scaling_fp8_shardy" ) +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + echo echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM # Run each test case across all GPUs for TEST_CASE in "${TEST_CASES[@]}"; do echo @@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Define output file for logs 
LOG_FILE="${TEST_CASE}_gpu_${i}.log" - # Run pytest and redirect stdout and stderr to the log file - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ - --num-process=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - done + # For process 0: show live output AND save to log file using tee + if [ $i -eq 0 ]; then + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \ + "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done # Wait for the process to finish wait - tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" + HAS_FAILURE=1 elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 - echo "... 
$TEST_CASE FAILED" fi # Remove the log file after processing it @@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do done wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + exit $HAS_FAILURE From dfeef1a26ba48ccbd690567a19137b2af8aeb7c9 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 29 Sep 2025 13:39:03 -0700 Subject: [PATCH 41/78] [JAX] Address tolerance check for current scaling dact dbias (#2211) Address tolerance check for current scaling dact Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9e39b84c0b..7f15eec892 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -780,9 +780,15 @@ def _test_quantize_dact_dbias( assert_allclose(te_output.data, jax_output.data) if is_dbias: - # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. precise_comparison = not ( - in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. + (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) + # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. 
+ or ( + activation_type == ("squared_relu",) + and in_dtype == jnp.bfloat16 + and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + ) ) assert_allclose( te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype From 3f5b47549567d13db76470073c8f0467c23d4fca Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 29 Sep 2025 14:12:26 -0700 Subject: [PATCH 42/78] [Core][PyTorch] NVFP4 recipe (#2177) * Add NVFP4 recipe Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Frank Sun Co-authored-by: Oleg Goncharov Co-authored-by: Zhongbo Zhu Co-authored-by: Evgeny Tsykunov Co-authored-by: Tim Moon Co-authored-by: Teddy Do * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add MathDx dependency to GitHub builds Signed-off-by: Tim Moon * Suggestions from GitHub Copilot Signed-off-by: Tim Moon * Move 2x shape logic from core to PyTorch Signed-off-by: Kirthi Shankar Sivamani * Fix compilation errors with CUDA 12.1 Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * SM 70 is not supported in CUDA 13 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Typo Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Revert "Move 2x shape logic from core to PyTorch" This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185. Signed-off-by: Tim Moon * Added dequantize kernel for FP4 Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warning Signed-off-by: Tim Moon * Add NVFP4 support with fusible ops Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable. 
Signed-off-by: Tim Moon * Fix logic for 2x shapes and move to PyTorch Signed-off-by: Kirthi Shankar Sivamani * Fix CG test model config Signed-off-by: Kirthi Shankar Sivamani * Debug NVFP4 tensor size function Signed-off-by: Tim Moon * Proper handling of the RNG state Signed-off-by: Przemek Tredak * Test SR properly Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix workspace size for GEMM heuristic. Signed-off-by: Kirthi Shankar Sivamani * Fix compile error in C++ NVFP4 test Some some numeric errors when blocks are all zero. Signed-off-by: Tim Moon * fix distrbuted test problem shape Signed-off-by: zhongboz * proper assert dim for low precision AG TP Signed-off-by: zhongboz * clean up duplicated code in nvfp4_utils.cuh Signed-off-by: zhongboz * lint Signed-off-by: zhongboz * pylint: disable=unused-argument Signed-off-by: zhongboz * `nvte_cublas_gemm_v2` to take alpha pointer (#12) * make nvte_cublas_gemm_v2 to take alpha/beta pointers Signed-off-by: Phuong Nguyen * users are expected to pass a valid C_tensor Signed-off-by: Phuong Nguyen * typos Signed-off-by: Phuong Nguyen * API to have const float* alpha Signed-off-by: Phuong Nguyen * Minor tweaks Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes. 
Signed-off-by: Tim Moon * Debug IMA with alpha pointer Signed-off-by: Tim Moon --------- Signed-off-by: Phuong Nguyen Signed-off-by: Tim Moon Co-authored-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support fused amax kernels with NVFP4 quantization Signed-off-by: Tim Moon * Disable fused amax with cuDNN LayerNorm kernel Signed-off-by: Tim Moon * Add NVFP4 cases to distributed tests for TE ops Signed-off-by: Tim Moon * Change assert to NVTE_CHECK in the hadamard cast fusion Signed-off-by: Przemek Tredak * Fix compile error Signed-off-by: Tim Moon * Use global thread IDs for Philox subsequences Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add shape checks for NVFP4 cast kernels Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Do not fuse amax if cuDNN normalization is forced by envvar Signed-off-by: Przemek Tredak --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Przemek Tredak Signed-off-by: zhongboz Signed-off-by: Phuong Nguyen Co-authored-by: Frank Sun Co-authored-by: Oleg Goncharov Co-authored-by: Zhongbo Zhu Co-authored-by: Evgeny Tsykunov Co-authored-by: Tim Moon Co-authored-by: Teddy Do Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Przemek Tredak Co-authored-by: Phuong Nguyen --- .github/workflows/build.yml | 8 +- benchmarks/benchmark_rht_cast.py | 152 ++ build_tools/utils.py | 15 +- pyproject.toml | 3 +- qa/L0_pytorch_unittest/test.sh | 1 + qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/cpp/operator/CMakeLists.txt | 8 + tests/cpp/operator/test_cast_mxfp8.cu | 42 +- 
.../operator/test_cast_mxfp8_gated_swiglu.cu | 54 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 741 ++++++++ tests/cpp/test_common.cu | 239 ++- tests/cpp/test_common.h | 37 +- tests/pytorch/distributed/run_numerics.py | 242 ++- .../pytorch/distributed/run_numerics_exact.py | 718 ++++++++ tests/pytorch/distributed/test_fusible_ops.py | 18 +- tests/pytorch/distributed/test_numerics.py | 7 +- .../distributed/test_numerics_exact.py | 70 + tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 243 +++ .../pytorch/nvfp4/test_nvfp4_module_exact.py | 559 ++++++ .../nvfp4/test_nvfp4_quantize_exact.py | 495 ++++++ .../nvfp4/test_nvfp4_rht_quantize_exact.py | 255 +++ tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 238 +++ tests/pytorch/test_cuda_graphs.py | 71 +- .../test_float8_current_scaling_exact.py | 9 +- tests/pytorch/test_fusible_ops.py | 121 +- tests/pytorch/test_recipe.py | 37 + tests/pytorch/test_sanity.py | 34 + tests/pytorch/utils.py | 25 +- transformer_engine/common/CMakeLists.txt | 30 +- transformer_engine/common/common.cu | 12 +- transformer_engine/common/common.h | 51 +- transformer_engine/common/gemm/config.cpp | 116 ++ transformer_engine/common/gemm/config.h | 36 + .../common/gemm/cublaslt_gemm.cu | 345 +++- .../hadamard_transform/hadamard_transform.cu | 876 ++++++++++ .../hadamard_transform_cast_fusion.cu | 841 +++++++++ .../common/include/transformer_engine/gemm.h | 189 +- .../transformer_engine/hadamard_transform.h | 68 + .../include/transformer_engine/recipe.h | 4 + .../transformer_engine/transformer_engine.h | 50 +- .../common/normalization/layernorm/ln_api.cpp | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- transformer_engine/common/recipe/__init__.py | 114 +- .../common/recipe/current_scaling.cu | 27 +- transformer_engine/common/recipe/nvfp4.cu | 54 + transformer_engine/common/swizzle/swizzle.cu | 265 +-- .../common/transformer_engine.cpp | 86 +- .../common/transpose/cast_transpose.h | 9 + ...quantize_transpose_vector_blockwise_fp4.cu | 842 
+++++++++ .../common/util/cast_gated_kernels.cuh | 5 +- .../common/util/cast_kernels.cuh | 807 ++++++++- .../common/util/dequantize_kernels.cuh | 110 +- .../common/util/nvfp4_transpose.cuh | 1515 +++++++++++++++++ transformer_engine/common/util/ptx.cuh | 82 +- .../common/util/pybind_helper.h | 3 +- transformer_engine/common/utils.cuh | 20 + transformer_engine/pytorch/constants.py | 2 + .../pytorch/cpp_extensions/gemm.py | 20 + transformer_engine/pytorch/csrc/common.cpp | 30 + transformer_engine/pytorch/csrc/common.h | 83 +- .../pytorch/csrc/extensions/activation.cpp | 244 ++- .../pytorch/csrc/extensions/attention.cpp | 18 +- .../pytorch/csrc/extensions/bias.cpp | 48 +- .../pytorch/csrc/extensions/gemm.cpp | 20 +- .../pytorch/csrc/extensions/normalization.cpp | 270 ++- .../pytorch/csrc/extensions/pybind.cpp | 19 + transformer_engine/pytorch/csrc/pybind.h | 20 +- transformer_engine/pytorch/csrc/quantizer.cpp | 590 ++++++- .../pytorch/csrc/type_converters.cpp | 40 + transformer_engine/pytorch/csrc/util.cpp | 55 +- transformer_engine/pytorch/distributed.py | 263 ++- .../pytorch/experimental/__init__.py | 10 + .../pytorch/experimental/config.py | 201 +++ .../pytorch/experimental/gemm.py | 139 ++ .../pytorch/experimental/quantization.py | 203 +++ .../quantization_microblock_ref.py | 811 +++++++++ .../pytorch/experimental/utils.py | 30 + transformer_engine/pytorch/fp8.py | 105 ++ transformer_engine/pytorch/module/_common.py | 38 +- transformer_engine/pytorch/module/base.py | 15 +- .../pytorch/module/layernorm_linear.py | 43 +- .../pytorch/module/layernorm_mlp.py | 48 +- transformer_engine/pytorch/module/linear.py | 45 +- .../pytorch/ops/basic/basic_linear.py | 8 + transformer_engine/pytorch/tensor/__init__.py | 3 + .../tensor/_internal/nvfp4_tensor_base.py | 348 ++++ .../pytorch/tensor/mxfp8_tensor.py | 5 +- .../pytorch/tensor/nvfp4_tensor.py | 898 ++++++++++ .../pytorch/tensor/quantized_tensor.py | 4 + transformer_engine/pytorch/tensor/utils.py | 21 +- 
transformer_engine/pytorch/triton/pad.py | 94 + transformer_engine/pytorch/utils.py | 14 +- 92 files changed, 15060 insertions(+), 753 deletions(-) create mode 100644 benchmarks/benchmark_rht_cast.py create mode 100644 tests/cpp/operator/test_cast_nvfp4_transpose.cu create mode 100644 tests/pytorch/distributed/run_numerics_exact.py create mode 100644 tests/pytorch/distributed/test_numerics_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_module_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py create mode 100755 tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py create mode 100644 transformer_engine/common/gemm/config.cpp create mode 100644 transformer_engine/common/gemm/config.h create mode 100644 transformer_engine/common/hadamard_transform/hadamard_transform.cu create mode 100644 transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu create mode 100644 transformer_engine/common/include/transformer_engine/hadamard_transform.h create mode 100644 transformer_engine/common/recipe/nvfp4.cu create mode 100644 transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu create mode 100644 transformer_engine/common/util/nvfp4_transpose.cuh create mode 100644 transformer_engine/pytorch/experimental/__init__.py create mode 100644 transformer_engine/pytorch/experimental/config.py create mode 100644 transformer_engine/pytorch/experimental/gemm.py create mode 100644 transformer_engine/pytorch/experimental/quantization.py create mode 100644 transformer_engine/pytorch/experimental/quantization_microblock_ref.py create mode 100644 transformer_engine/pytorch/experimental/utils.py create mode 100644 transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/nvfp4_tensor.py create mode 100644 
transformer_engine/pytorch/triton/pad.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f40b281895..506bc83f08 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja + pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -43,7 +43,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -63,7 +63,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install pybind11[global] + run: pip install pybind11[global] nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -83,7 +83,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript + run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py new file mode 100644 index 0000000000..9c47856f71 --- /dev/null +++ b/benchmarks/benchmark_rht_cast.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +import transformer_engine.pytorch.cpp_extensions as ext + +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +scale_padding_to = 1 +permute_scale = False + +TORCH_TO_TE_FLOAT_MAP = { + torch.bfloat16: tex.DType.kBFloat16, +} + + +def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16): + # Generate random input data + M, K = shape + x = torch.randn([M, K], dtype=input_dtype, device="cuda") + + assert shape[0] % 16 == 0, "Shape must be divisible by 16" + assert shape[1] % 16 == 0, "Shape must be divisible by 16" + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + stochastic_rounding=stochastic_rounding, + ) + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, K), dtype=x.dtype, device=x.device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + with torch.no_grad(): + stmt = "kernel_func(input, output)" + globals_dict = { + "kernel_func": nvfp4_quantizer.update_quantized, + "input": x, + "output": x_nvfp4_sut, + } + + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(timing) + timing_us = timing.median * 1e6 + + input_nbytes = shape[0] * shape[1] * 2 # bf16 + output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4 + sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems + + total_nbytes = ( + 0 + + input_nbytes + * 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reading input for Cast(RHT(x.T)) + + 2 * 4 # Output 2 * float for scale & amax + + 2 * 4 # Input 2 * float + + output_nbytes * 2 #
Output from Cast(x) and Cast(RHT(x.T)) + + sf_nbytes * 2 # Scale factor + ) + + throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6) + + print( + f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:" + f" {throughput_GBps} GB/s" + ) + return timing_us, throughput_GBps + + +# Nsight Compute Profiling Command: +# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + args = parser.parse_args() + + if args.profile: + print("Profiling is enabled.") + else: + print("Profiling is disabled.") + + shapes = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + + if args.profile: + shapes = [ + (16384, 6144), + ] + + data = [] + for stochastic_rounding in [True]: # , False]: + for shape in shapes: + print( + f"Running benchmark_func with shape {shape} and stochastic_rounding" + f" {stochastic_rounding}" + ) + timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding) + data.append( + [ + "benchmark_func", + shape, + stochastic_rounding, + timing_us, + throughput_GBps, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "kernel", + "shape", + "stochastic_rounding", + "timing_us", + "throughput(GB/s)", + ], + ) + print(df) + df.to_csv("benchmark_cast_nvfp4.csv", index=False) diff --git a/build_tools/utils.py b/build_tools/utils.py index 23fb565983..3d8ec462c8 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -234,15 +234,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - version 
= cuda_version() - if os.getenv("NVTE_CUDA_ARCHS") is None: + archs = os.getenv("NVTE_CUDA_ARCHS") + if archs is None: + version = cuda_version() if version >= (13, 0): - os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + archs = "75;80;89;90;100;100a;103a;120" + elif version >= (12, 9): + archs = "70;80;89;90;100;100a;103a;120" elif version >= (12, 8): - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + archs = "70;80;89;90;100;100a;120" else: - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" - return os.getenv("NVTE_CUDA_ARCHS") + archs = "70;80;89;90" + return archs def cuda_version() -> Tuple[int, ...]: diff --git a/pyproject.toml b/pyproject.toml index 64ff4c5cea..8692ad9610 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", -"torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 394273ca47..cdf0df8887 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -31,6 +31,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || 
test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 19889946a6..e698e997a6 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml 
$TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 498c1d3944..479d378ba6 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu @@ -31,6 +32,13 @@ add_executable(test_operator test_swap_first_dims.cu ../test_common.cu) +# Add profiling and debug flags for CUDA compilation +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage +# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping + +# Find required packages find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 49bbf16556..3800921446 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method, // Cache computations for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); @@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method, const double rel_tolerable_mismatches_limit = 0.0; size_t mismatches_scales = 0; - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - 
scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + + compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); const size_t mismatches_elts = 32 * mismatches_scales; auto [atol, rtol] = getTolerances(otype); @@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method, const double rel_tolerable_mismatches_limit = 0.0; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const 
size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 464b771288..512ee7e810 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -267,19 +267,20 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + } const size_t mismatches_elts = 32 * mismatches_scales; @@ -378,21 +379,22 @@ void performTest_x2(const size_t rows, const double rel_tolerable_mismatches_limit = 1.0e-4; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, - scale_diff_abs_tolerance, - 
abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu new file mode 100644 index 0000000000..e905a00640 --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -0,0 +1,741 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(float (*OP)(const float), + const InputType* const input, + 
fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const float S_dec_b_fp8 = S_dec_b * S_enc; + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 
0.f : S_enc / S_dec_b_fp8; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = static_cast(S_dec_b_fp8); + const float scale_reciprocal = S_enc_b_fp8; + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +// Compute 2D mathematical scaling factors (8x8 for 128x128 input) +template +void compute_2d_mathematical_scales(float (*OP)(const float), + const InputType* const input, + const size_t rows, + const size_t cols, + const float global_amax, + std::vector>& math_scales) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + math_scales.resize(blocks_Y, std::vector(blocks_X)); + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Find 2D block amax over entire 16x16 region + float block_amax = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float input_elt = 
static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + block_amax = std::max(block_amax, std::abs(elt)); + } + } + + // Compute E4M3 scaling factor for this 16x16 block + const float S_dec_b = block_amax / 6.0f; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + math_scales[block_Y][block_X] = S_dec_b_fp8; + } + } +} + +// 2D Scaling: NEW implementation with proper replication +template +void quantize_nvfp4_2d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Step 1: Compute mathematical 8x8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr + if (scales != nullptr) { + // Each of the 128 rows gets scaling factors from its corresponding 16×16 block + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + } + + // Step 3: Apply quantization using the mathematical scaling factors + std::array, block_size_Y> cache_buffer; + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * 
block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Get the scaling factor for this block + const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + // Process and cache data for this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx_y][cache_idx_x] = elt; + } + } + + // Apply scaling to all elements in this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x1 = j - j_min; + const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1); + + const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1]; + const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ? 
+ cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f; + + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + } + } + } + } +} + +// Wrapper function that calls appropriate implementation based on 2D flag +template +void quantize_nvfp4(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_2d_quantization = false) { + if (use_2d_quantization) { + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } else { + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } +} + +template +void compute_ref(float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_2d_quantization = false) +{ + std::vector input_t = create_transpose(input, rows, cols); + + if (use_2d_quantization) { + // Step 1: Compute mathematical 8×8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Generate scales (128×8) by replicating row-wise + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + 
scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + + // Step 3: Generate scales_t (128×8) with proper transposed block mapping + for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data + const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X + for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate + const size_t scale_idx = i * scales_stride_t + block_Y_new; + scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig]; + } + } + + // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d + // (This part processes the actual FP4 data using the mathematical scaling factors) + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + + } else { + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + } +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? 
ref_data_pair.x : ref_data_pair.y); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + total_mismatches++; + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... 
and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 
0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? 
"✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(float (*OP)(const float), + const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions + // Now that CheckScaleTensorShape is fixed, this should work correctly + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = 
scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + fillCase(&input, InputsFillCase::uniform); + + // Find global amax + float amax = 0.0f; + const InputType* input_dptr = input.rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + amax = fmaxf(amax, static_cast(input_dptr[idx])); + } + } + // Set 2nd stage NVFP4 scaling factor + output.set_scale(amax); + + bool use_2d_quantization = false; + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_2d_quantization); + + QuantizationConfigWrapper quant_config; + + // Initialize stochastic rounding + Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); + rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed + rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence + rng_state.from_cpu(); + quant_config.set_stochastic_rounding(false); + quant_config.set_rng_state(rng_state.data()); + + // Set 2D quantization based on compile-time flag + quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + + // Call appropriate function based on operation type + // Activation functions take 3 parameters (input, output, stream) + // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream) + if (OP == &gelu) { + 
nvte_gelu(input.data(), output.data(), 0); + } else if (OP == &silu) { + nvte_silu(input.data(), output.data(), 0); + } else if (OP == &relu) { + nvte_relu(input.data(), output.data(), 0); + } else if (OP == &qgelu) { + nvte_qgelu(input.data(), output.data(), 0); + } else if (OP == &srelu) { + nvte_srelu(input.data(), output.data(), 0); + } else { + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err)); + } + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + const double atol = 0.05; + const double rtol = 0.1; + + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); + + const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_ptr = ref_scales.get(); + const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, + {128, 256}, + {8192, 128}, + {2048, 160}, + {8, 32, 1024}, + {16, 8, 4, 512}, + {1024, 16384}, + {4096, 13312}, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + ActivationType::SiLU, + ActivationType::ReLU, + 
ActivationType::QGeLU, + ActivationType::SReLU, +}; + +} // namespace + +class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ActivationType Act_type = std::get<0>(GetParam()); + const auto tensor_dims = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + + // Skip tests if the input tensor is 1D + if (tensor_dims.size() < 2) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(OP, tensor_dims); + ); +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "CAST_ONLY"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kBFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + 
test::typeName(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f974d9083d..cdbfb05b3c 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -107,6 +107,10 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } +size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ + return DIVUP(x, y) * y; +} + struct scale_inv_meta { std::vector shape; DType type; @@ -143,21 +147,71 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul, 4ul}; - { - auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; - alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type = DType::kFloat8E8M0; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } 
- { - auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; - alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 32 == 0); + NVTE_CHECK(first_dim % 32 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; + + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; + + ret_rowwise.type = DType::kFloat8E4M3; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type = DType::kFloat8E4M3; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 
block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); @@ -176,13 +230,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -202,13 +256,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(first_dim, 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(last_dim, 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -250,14 +304,15 @@ Tensor::Tensor(const std::string& name, NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + if 
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING + || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } } else { - // Same shape for MX + // Same shape for MX and NVFP4 for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } @@ -283,10 +338,13 @@ Tensor::Tensor(const std::string& name, std::fill_n(cpu_data_columnwise_.get(), total_size, 0); } } - tensor_.set_rowwise_data(dptr_rowwise, type, shape); - tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); - if (isFp8Type(type)) { + const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); + tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); + + if (isFp8Type(type) || isFp4Type(type)) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); @@ -305,13 +363,19 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); + std::vector{1}); columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode()); + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // Used for NVFP4 second stage scaling + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, 
std::vector{1}); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; @@ -346,13 +410,16 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { + const DType colwise_type = tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); + tensor_.get_columnwise_data().data_ptr, + colwise_size, + cudaMemcpyDeviceToHost); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), @@ -364,8 +431,7 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -394,15 +460,15 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - 
auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -419,7 +485,7 @@ void Tensor::from_cpu() const { } void Tensor::set_scale(float scale) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; @@ -429,7 +495,7 @@ void Tensor::set_scale(float scale) { } void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { NVTE_CHECK(rowwise_scale_inv_cpu_data_); } @@ -437,8 +503,7 @@ void Tensor::set_scale_inv(float scale_inv) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { @@ -468,7 +533,8 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), @@ -681,12 +747,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t 
stride, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + const size_t N = row_blocks * col_blocks; const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); @@ -696,11 +780,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - const int test_val = static_cast(test[idx]); - const int ref_val = static_cast(ref[idx]); - const int abs_delta = std::abs(test_val - ref_val); + float t, r; + + bool assertion = false; - if (abs_delta > atol) { + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? 
mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { mismatches_num++; mismatch_indices.push_back(idx); } @@ -708,8 +812,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, std::cout << "Error in " << name << std::endl; for (const int index : mismatch_indices) { std::cout << "Mismatch at (" << index << "):" - << static_cast(test[index]) << " vs " - << static_cast(ref[index]) << std::endl; + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " << tolerable_mismatches_limit << "."; @@ -718,6 +822,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: @@ -873,6 +993,10 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + int32_t getDeviceComputeCapability() { cudaDeviceProp 
deviceProp; cudaGetDeviceProperties(&deviceProp, 0); @@ -894,7 +1018,8 @@ std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index d1e273c6d8..b8993dfb62 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -62,6 +62,8 @@ using fp8e5m2 = __nv_fp8_e5m2; using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template @@ -223,7 +225,9 @@ class Tensor { float scale() const { if(scale_cpu_data_) { - NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), + "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -237,6 +241,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -250,6 +256,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if 
(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -304,10 +312,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; @@ -456,12 +464,14 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); + std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); @@ -484,6 +494,7 @@ const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool 
isFp8Type(DType type); +bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; @@ -561,7 +572,7 @@ constexpr int32_t blackwellComputeCapability = 100; SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type MARKED TEST."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -580,7 +591,7 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 2."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ @@ -588,7 +599,7 @@ constexpr int32_t blackwellComputeCapability = 100; using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 3."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) 
\ @@ -613,5 +624,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 4."); \ + NVTE_ERROR("Invalid type."); \ } diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 21aab6336b..a4aa74bd8f 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -9,6 +9,7 @@ import os import sys from functools import wraps +import math import transformer_engine.pytorch as te import torch @@ -20,10 +21,15 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, Format, Recipe, + QParams, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -47,6 +53,14 @@ ) +def nvfp4_vanilla(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams() + nvfp4_recipe.fp4_quant_fwd_weight = QParams() + nvfp4_recipe.fp4_quant_bwd_grad = QParams() + return nvfp4_recipe + + # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -59,6 +73,8 @@ def quantization_recipe() -> Recipe: return Float8CurrentScaling() if QUANTIZATION == "fp8_block_scaling": return Float8BlockScaling() + if QUANTIZATION == "nvfp4": + return nvfp4_vanilla() return te.fp8.get_default_fp8_recipe() @@ -96,10 +112,14 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"): SEQ_LEN = 32 BATCH_SIZE = 32 HIDDEN_SIZE = 128 + # For fp8 block scaling, block size is 128, + # and to make 
low precision TP work, input tensor + # must be 128x128 divisible to be eligible for + # low precision All-Gather when needed elif QUANTIZATION == "fp8_block_scaling": SEQ_LEN = 128 BATCH_SIZE = 128 @@ -107,6 +127,7 @@ def main(argv=None, namespace=None): test_dict = [ test_quantizer, + test_quantized_all_gather, test_linear, test_layernorm, test_layernorm_linear, @@ -176,6 +197,9 @@ def _get_tolerances(dtype): # row parallel & sequence parallel, because we do the all_gather in backward pass if QUANTIZATION == "fp8_cs": return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION == "nvfp4": + # TODO(zhongboz): investigate why the tolerance is so large + return {"rtol": 0.125, "atol": 0.12} elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} @@ -326,24 +350,36 @@ def _alloc_main_grad(model_single_node, model_distributed): ############################################### # Quantizer # ############################################### -def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): +def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size): """ quantizer is the reference quantizer on a single GPU. quantizer_dist is the distributed quantizer to be tested on multiple GPUs. 
""" if quantizer_class == Float8CurrentScalingQuantizer: quantizer_dist = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=True, amax_reduction_group=tp_group, ) quantizer = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=False, ) return quantizer, quantizer_dist + elif quantizer_class == NVFP4Quantizer: + quantizer_dist = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=True, + amax_reduction_group=tp_group, + ) + quantizer = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=False, + amax_reduction_group=None, + ) + return quantizer, quantizer_dist else: raise ValueError(f"Unsupported quantizer class: {quantizer_class}") @@ -414,6 +450,194 @@ def test_quantizer(): _test_quantizer(input_dtype, fp8_dtype) +############################################ +# Quantized All-Gather # +############################################ + + +def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape): + """ + Zero padding the scale_inv. 
+ scale_inv shape is the padded shape, but not zero padded + unpadded_shape is the original shape before padding + """ + dim0, dim1 = scale_inv.shape + unpadded_dim0, unpadded_dim1 = unpadded_shape + pad_dim0 = (128 - unpadded_dim0 % 128) % 128 + pad_dim1 = (4 - unpadded_dim1 % 4) % 4 + new_dim0 = unpadded_dim0 + pad_dim0 + new_dim1 = unpadded_dim1 + pad_dim1 + + assert dim0 == new_dim0 + assert dim1 == new_dim1 + + # return input if no padding is needed + if pad_dim0 == 0 and pad_dim1 == 0: + return scale_inv + + # unpad first to remove random bits from torch empty + scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous() + # using torch padding + new_scale_inv = torch.nn.functional.pad( + scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0 + ) + + assert new_scale_inv.shape == (new_dim0, new_dim1) + + return new_scale_inv + + +def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise): + """ + Calculate the unpadded shape of the scale_inv tensor. + """ + M, K = 1, 1 + M = math.prod(input_shape[:-1]) + K = input_shape[-1] + + if quantizer_cls == NVFP4Quantizer: + if columnwise: + outer = K + inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + outer = M + inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_cls}") + + +@run_distributed_test() +def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8. 
+ """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2 + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + # x_hp_cpu[M - 1, N - 1] = 1e4 + + # get the unpadded shapes + unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False) + unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True) + + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the entire input + if WORLD_RANK == 0: + x_low_precision_single = quantizer(x_hp_rank0) + + # run all-gather with a quantizer as input for quantized all-gather + x_low_precision_total, _ = gather_along_first_dim( + x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist + ) + + # check the outputs + if WORLD_RANK == 0: + # assert all data and scale_inv are the same + torch.testing.assert_close( + x_low_precision_single._rowwise_data, + x_low_precision_total._rowwise_data, + rtol=0.0, + atol=0.0, + ) + # check the rowwise scale without any padding + unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape + unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_rowwise_scale_inv_ref, + unpadded_rowwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + 
x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + x_low_precision_single._columnwise_data, + x_low_precision_total._columnwise_data, + rtol=0.0, + atol=0.0, + ) + unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape + unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_columnwise_scale_inv_ref, + unpadded_columnwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + + +def test_quantized_all_gather(): + """ + Run quantized all-gather tests with various configurations. 
+ """ + # skip this test for other quantization schemes + is_nvfp4 = QUANTIZATION == "nvfp4" + # add other recipes for testing if needed + if not is_nvfp4: + return + + input_dtypes = [torch.bfloat16] + fp4_dtype = [tex.DType.kFloat4E2M1] + fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + quantizer_cls_nvfp4 = [NVFP4Quantizer] + # add FP8 quantizers if needed + quantizer_cls_fp8 = [] + + low_precisio_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype + quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8 + + for quantizer_cls in quantizer_cls_list: + for input_dtype in input_dtypes: + for low_precision_dtype in low_precisio_dtypes: + _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls) + + ############################################ # Linear # ############################################ @@ -514,10 +738,11 @@ def test_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"delay_wgrad_compute": True}, {"save_original_input": True}, ] + for kwargs in kwargs_list: if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": continue @@ -693,11 +918,12 @@ def test_layernorm_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, {"delay_wgrad_compute": True}, ] + for kwargs in kwargs_list: for parallel_mode in ["column"]: for sequence_parallel in [False, True]: @@ -799,7 +1025,7 @@ def test_layernorm_mlp(): {"normalization": "RMSNorm"}, {"zero_centered_gamma": True}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"activation": "relu"}, 
{"fuse_wgrad_accumulation": True}, {"return_bias": True}, @@ -897,7 +1123,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, {"qkv_weight_interleaved": False}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"fuse_qkv_params": True}, {"activation": "relu"}, ] diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py new file mode 100644 index 0000000000..b1722b79a8 --- /dev/null +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -0,0 +1,718 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import datetime +import os +import sys +from functools import wraps +import math + +import transformer_engine.pytorch as te +import torch +from torch import nn +import torch.distributed as dist +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + Format, + Recipe, + QParams, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from run_layer_with_overlap import _compare_tensors + + +BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128 +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +LOSS_FN = nn.MSELoss() +QUANTIZATION = None + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +# Quantization recipe setup +def quantization_recipe() -> 
Recipe: + if QUANTIZATION == "nvfp4": + return nvfp4_rht_and_2d_quantization() + raise ValueError(f"Unsupported quantization: {QUANTIZATION}") + + +def setup_environment_for_reference(): + if QUANTIZATION == "nvfp4": + os.environ["QAT_PARAMS"] = "9003" + else: + raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def main(argv=None, namespace=None): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--hidden-size", type=int, default=128) + parser.add_argument("--out-size", type=int, default=128) + args = parser.parse_args(argv, namespace) + + # Quantization scheme + QUANTIZATION = args.quantization + BATCH_SIZE = args.batch_size + HIDDEN_SIZE = args.hidden_size + OUT_SIZE = args.out_size + + test_dict = [ + test_linear, + test_layernorm_linear, + ] + + for test in test_dict: + test() + dist.destroy_process_group() + return 0 + + +def 
run_distributed_test(test_name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + name = test_name if test_name is not None else func.__name__ + + dist_print(f"Starting test {name} with args {args} and {kwargs}") + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + func(*args, **kwargs) + + dist.barrier() + dist_print(f"Passed test {name}") + + return wrapper + + return decorator + + +def dist_print(msg, src=None, end="\n", error=False): + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + + +############################################ +# Linear # +############################################ +class TestDistributedLinearBase: + @staticmethod + def _prepare_data( + batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32 + ): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") + w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda") + bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None + gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda") + + return x, w, bias, gradient + + @staticmethod + def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad)) + return out + + @staticmethod + def _gather_tensor(local, world_size, tp_group, concat_dim): + out_list = [torch.zeros_like(local) for _ in range(world_size)] + torch.distributed.all_gather(out_list, local, tp_group) + return torch.cat(out_list, dim=concat_dim) + + @staticmethod + def _all_reduce_tensor(local, world_size, tp_group): + if world_size == 1: + return local + handle = 
torch.distributed.all_reduce(local, group=tp_group, async_op=False) + return local + + @staticmethod + def _get_sum_abs_error(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _get_mean_abs_relative_error(a, b): + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) + + @classmethod + def run_linear_preprocess_parallel( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_size=1, + rank=0, + ): + if tp_size > 1: + if parallel_mode == "column": + # split w in N dim, which should be axis 0 + w = cls._shard_tensor(w, tp_size, 0)[rank] + bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None + # split gradient in N dim, which should be axis 1 + gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + if sequence_parallel: + # split x in M dim, which should be axis 0 + x = cls._shard_tensor(x, tp_size, 0)[rank] + # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1 + if parallel_mode == "row": + # split x in K dim, which should be axis 1 + x = cls._shard_tensor(x, tp_size, 1)[rank] + # split w in K dim, which should be axis 1 + w = cls._shard_tensor(w, tp_size, 1)[rank] + if sequence_parallel: + # split gradient in M dim, which should be axis 0 + gradient = cls._shard_tensor(gradient, tp_size, 0)[rank] + return x, w, bias, gradient + + @classmethod + def run_linear_postprocess_parallel( + cls, + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ): + if tp_size > 1: + if parallel_mode == "column": + # gather y_q in N dim, which should be axis 1 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1) + # gather wgrad in N dim, which should be axis 0 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0) + # gather bgrad in N dim, which should be axis 0 + bgrad = ( + cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None + ) + if 
sequence_parallel: + # gather dgrad in M dim, which should be axis 0 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0) + if parallel_mode == "row": + # gather dgrad in K dim, which should be axis 1 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1) + # gather wgrad in K dim, which should be axis 1 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1) + if sequence_parallel: + # gather y_q in M dim, which should be axis 0 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0) + # we need to sum bias gradient when using TP + SP + bgrad = ( + cls._all_reduce_tensor(bgrad, tp_size, tp_group) + if bgrad is not None + else None + ) + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_one_step( + cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False + ): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + if isinstance(layer, te.Linear): + # Kitchen Linear + y_q = layer.forward(x, is_first_microbatch=is_first_microbatch) + else: + # the default torch.nn.Linear + y_q = layer(x) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + bgrad = ( + layer._parameters["bias"].grad + if layer._parameters.get("bias", None) is not None + else None + ) + assert "weight" in layer._parameters + if fuse_wgrad_accumulation: + wgrad = layer._parameters["weight"].main_grad + assert layer._parameters["weight"].grad is None + else: + wgrad = layer._parameters["weight"].grad + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation=False, + ): + """ + Run multiple steps of linear layer and collect results. 
+ """ + + y_q_list, dgrad_list, wgrad_list = [], [], [] + bgrad_list = [] if layer._parameters.get("bias", None) is not None else None + + for i in range(run_num_steps): + x_i = (x + i).clone().detach().requires_grad_(True) + # run_linear_one_step + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step( + layer, + x_i, + gradient, + is_first_microbatch=(i == 0) if enable_weight_cache else None, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + # Collect results + y_q_list.append(y_q.detach().clone()) + dgrad_list.append(dgrad.detach().clone()) + wgrad_list.append(wgrad.detach().clone()) + if bgrad_list is not None and bgrad is not None: + bgrad_list.append(bgrad.detach().clone()) + + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + + @classmethod + def run_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + fuse_wgrad_accumulation=False, + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. 
+ """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = te.Linear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + layer = layer.to("cuda") + + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + if fuse_wgrad_accumulation: + assert ( + run_num_steps > 1 + ), "Fused weight gradient accumulation requires run_num_steps > 1" + layer.weight.main_grad = torch.zeros_like(layer.weight) + + # Run one step or multiple steps + if run_num_steps == 1: + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + else: + y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps( + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation, + ) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence 
parallelization. + + Args: + parallel_mode (str): 'row' or 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + + QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # turn on weight quantization cache when fusing wgrad accumulation + enable_weight_cache = fuse_wgrad_accumulation + run_num_steps = 1 if not fuse_wgrad_accumulation else 5 + + x, w, bias, gradient = TestDistributedLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=1 if not fuse_wgrad_accumulation else 5, + enable_weight_cache=fuse_wgrad_accumulation, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=run_num_steps, + enable_weight_cache=enable_weight_cache, + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + 
torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_linear(): + """Run linear layer tests with various configurations.""" + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": + continue + for parallel_mode in ["column", "row"]: + for sequence_parallel in [False, True]: + _test_linear(parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# LayerNormLinear # +############################################ +class TestDistributedLayerNormLinearBase(TestDistributedLinearBase): + + @classmethod + def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + 
LayerNormLinearClass=te.LayerNormLinear, + normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. 
+ + Args: + parallel_mode (str): 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # run multiple steps currently not supported for LayerNormLinear + run_num_steps = 1 + + x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( + TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch") + torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, 
atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_layernorm_linear(): + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + for parallel_mode in ["column"]: + for sequence_parallel in [False, True]: + _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 8ca1fcc1cb..11fe4333bc 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -27,6 +27,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -34,17 +35,20 @@ # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, quantization_tols # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") +if nvfp4_available: + quantization_list.append("nvfp4") @functools.cache @@ -115,6 +119,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + 
with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -437,7 +449,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -609,7 +621,7 @@ def _test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1ff5aff997..d09c530cba 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -31,6 +31,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -51,7 +52,9 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) +@pytest.mark.parametrize( + "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"] +) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -61,4 +64,6 @@ def test_distributed(quantization): pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: 
pytest.skip(reason_for_no_fp8_block_scaling) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + _run_test(quantization) diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py new file mode 100644 index 0000000000..890a248044 --- /dev/null +++ b/tests/pytorch/distributed/test_numerics_exact.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +""" + Distributed numerics tests + + This numerical test aims for a zero-tolerance test for absolute confidence in numerics. + In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise + result with the native silicon. For distributed test cases, we can do the same thing + by comparing BF16 AG results with the low precision AG results at layer level. 
+""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization, batch_size, hidden_size, out_size): + test_path = TEST_ROOT / "run_numerics_exact.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + test_cmd += ["--quantization", quantization] + test_cmd += ["--batch-size", str(batch_size)] + test_cmd += ["--hidden-size", str(hidden_size)] + test_cmd += ["--out-size", str(out_size)] + + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +all_boolean = [True, False] + + +@pytest.mark.parametrize("quantization", ["nvfp4"]) +@pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (64, 128, 128), + (128, 128, 128), + (128, 256, 256), + (512, 1024, 768), + (512, 256, 1024), + (2048, 2048, 2048), + ], +) +def test_distributed(quantization, batch_size, hidden_size, out_size): + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + _run_test(quantization, batch_size, hidden_size, out_size) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py new file mode 100644 index 0000000000..a9e73aaf9f --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def check_nvfp4_gemm_versus_reference( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, + accumulate: bool, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, +): + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input tensors + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + x = torch.randn(x_shape, dtype=x_dtype, device=device) + w = torch.randn(w_shape, dtype=w_dtype, device=device) + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Native TE NVFP4 quantization + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + # Quantize x and w + x_nvfp4_native = x_quantizer.make_empty( + x_shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native) + w_nvfp4_native = w_quantizer.make_empty( + w_shape, 
dtype=w_dtype, device=device, requires_grad=False + ) + w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native) + + # Extract quantized data from native NVFP4Tensors + qx_data = ( + x_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if x_columnwise + else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + qw_data = ( + w_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if w_columnwise + else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + sx_native = ( + x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv + ) + sw_native = ( + w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv + ) + + # Trim quantized data to match the actual tensor dimensions (remove padding) + qx_data = qx_data[:M, :] + qw_data = qw_data[:N, :] + + # NVFP4 uses 16-element blocks, trim scales to remove padding + block_length = 16 # NVFP4 uses 16-element blocks + expected_sx_cols = expected_sw_cols = K // block_length + # Trim the scales to remove padding + sx_trimmed = sx_native[:M, :expected_sx_cols] + sw_trimmed = sw_native[:N, :expected_sw_cols] + + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn + # for the reference GEMM to work correctly + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + + # Create reference quantizer for reference GEMM + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=True, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + + # Create reference quantized tensors needed by reference GEMM + x_nvfp4_ref = ref_quantizer.quantize(x) + w_nvfp4_ref = ref_quantizer.quantize(w) + + # Reference GEMM using quantizer's qgemm method + y_ref = ref_quantizer.qgemm( + qx=qx_data, + qw=qw_data, + m_params=None, # MMParams not used in reference + out_dtype=out_dtype, + sx=sx_trimmed, + sw=sw_trimmed, + bias=None, # No bias 
for this test + out=out.clone() if accumulate else None, + accumulate=accumulate, + gemm_type=None, # GEMMType not used in reference + qresult_x=x_nvfp4_ref, + qresult_w=w_nvfp4_ref, + ) + + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Allocate cuBLAS workspace + workspace = torch.empty(4, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + bias = None + bias_dtype = TE_DType[torch.bfloat16] + use_gelu = False + gelu_input = None + use_grad = False + use_split_accumulator = False + + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y_native are not the same tensor + assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native) + + # Compare results with some tolerance + torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", 
[torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (False, False), # Only rowwise x rowwise is supported by reference GEMM + # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization + # Columnwise layouts are not supported by the reference implementation + ], + ids=["rowxrow"], +) +def test_nvfp4_gemm_versus_reference( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + accumulate: bool, + is_x_columnwise: bool, + is_w_columnwise: bool, +): + check_nvfp4_gemm_versus_reference( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py new file mode 100644 index 0000000000..ae99758399 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -0,0 +1,559 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import os +import pytest +import torch +import transformer_engine as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common import recipe + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +class GetRecipes: + @staticmethod + def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True) + return nvfp4_recipe + + @staticmethod + def nvfp4_2d_quantization_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False) + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + @staticmethod + def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + return 
GetRecipes.nvfp4_rht_and_2d_quantization() + elif with_rht: + return GetRecipes.nvfp4_rht_only() + elif with_2d_quantization: + return GetRecipes.nvfp4_2d_quantization_only() + else: + return GetRecipes.nvfp4_vanilla() + + +def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + os.environ["QAT_PARAMS"] = "9003" + elif with_rht: + os.environ["QAT_PARAMS"] = "960109" + elif with_2d_quantization: + os.environ["QAT_PARAMS"] = "9002" + else: + os.environ["QAT_PARAMS"] = "6010" + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def reset_rng_states(): + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def check_nvfp4_module_versus_reference( + module_class, + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 module against reference implementation. 
+ + Args: + module_class: te.Linear or te.LayerNormLinear + in_features: Input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + x_dtype: Input tensor dtype + num_steps: Number of forward/backward steps to test + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Create native module + print("\nCreate native module") + if module_class == te.pytorch.Linear: + native_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + else: + raise ValueError(f"Unsupported module class: {module_class}") + + # Create reference module with same weights + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + + # Create reference module + print("Create reference module") + if module_class == te.pytorch.Linear: + ref_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + + # Sync weights between native and reference modules + with torch.no_grad(): + # Copy main weight and bias parameters + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + + # Copy layer norm parameters if they exist + if hasattr(native_module, "layer_norm_weight") 
and hasattr(ref_module, "layer_norm_weight"): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + # Training loop comparison + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + # enable weight cache by giving is_first_microbatch + y_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast( + enabled=True, fp8_recipe=nvfp4_recipe + ): # Exact recipe does not play a role here + y_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + # Store results + native_outputs.append( + { + "output": y_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + + 
ref_outputs.append( + { + "output": y_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results across all steps + for step in range(num_steps): + native_out = native_outputs[step] + ref_out = ref_outputs[step] + + # Compare outputs + torch.testing.assert_close( + native_out["output"], + ref_out["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + + # Compare input gradients + torch.testing.assert_close( + native_out["input_grad"], + ref_out["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:" + f" {ref_out['input_grad']}" + ), + ) + + # Compare weight gradients + torch.testing.assert_close( + native_out["weight_grad"], + ref_out["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Weight gradient mismatch at step {step}. 
Native: {native_out['weight_grad']}," + f" Ref: {ref_out['weight_grad']}" + ), + ) + + # Compare bias gradients + if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None: + torch.testing.assert_close( + native_out["bias_grad"], + ref_out["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + # Clean up + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + (512, 512), + (768, 3072), + (1024, 4096), + ], +) +# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + """Test NVFP4 Linear module against reference implementation.""" + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_module_versus_reference( + module_class=te.pytorch.Linear, + in_features=in_features, + out_features=out_features, + bias=bias, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) + + +def check_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + 
Compare native NVFP4 LayerNormLinear module against reference implementation, + including ln_out. + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Native module + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Reference module + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Sync weights and LN params + with torch.no_grad(): + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + if ( + native_module.layer_norm_weight is not None + and ref_module.layer_norm_weight is not None + ): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None: + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, 
device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_ref, ln_out_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + native_outputs.append( + { + "output": y_native.detach().clone(), + "ln_out": ln_out_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "ln_out": ln_out_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results + for step in range(num_steps): + n = native_outputs[step] + r = ref_outputs[step] + torch.testing.assert_close( + n["output"], + r["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + torch.testing.assert_close( + n["ln_out"], + 
r["ln_out"], + atol=1e-6, + rtol=1e-6, + msg=f"LN output mismatch at step {step}", + ) + torch.testing.assert_close( + n["input_grad"], + r["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Input gradient mismatch at step {step}", + ) + torch.testing.assert_close( + n["weight_grad"], + r["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Weight gradient mismatch at step {step}", + ) + if bias and n["bias_grad"] is not None and r["bias_grad"] is not None: + torch.testing.assert_close( + n["bias_grad"], + r["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + ], +) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1], ids=["single_step"]) +@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_layernorm_linear_versus_reference( + in_features=in_features, + out_features=out_features, + bias=bias, + normalization=normalization, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py 
b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py new file mode 100644 index 0000000000..dc3c4a4e9a --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = 
nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + # Reference quantization + quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=quant_tile_shape, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + if x_nvfp4_ref.data_t is not None + else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of 
transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # # largest tile + (8192, 8192), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_transpose=return_transpose, + swizzled_scale=swizzled_scale, + use_cpp_allocator=use_cpp_allocator, + with_2d_quantization=with_2d_quantization, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) 
+@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + extrema_high: bool, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + 
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (16, 128), + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_boundary_values( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + """ + Stress rounding/threshold behavior by placing values just below/above + many potential bin edges within each 16-element microblock. + Validates native vs reference byte-for-byte and scale parity. + """ + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Construct a single row with paired boundary values: v-eps, v+eps + # spanning a wide dynamic range to exercise clipping and multiple bins. + # Ensure even N and N is multiple of 16 for microblocks, which holds for 128. 
+ base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device) + eps = torch.full_like(base, 1e-3) + # Avoid zero eps for very small magnitudes + eps = torch.maximum(eps, 1e-4 * torch.ones_like(base)) + lower = base - eps + upper = base + eps + row = torch.empty(N, dtype=torch.float32, device=device) + row[0::2] = lower + row[1::2] = upper + x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, 
atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim any padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 17 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Start from a contiguous tensor, then make a non-contiguous view by transpose + x_base = torch.randn((M, N), dtype=x_dtype, device=device) + x_nc = x_base.t() # shape (N, M), non-contiguous + assert not x_nc.is_contiguous() + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + + assert 
x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x_nc) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + # Quantized must match + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py new file mode 100644 index 0000000000..bb542456e5 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -0,0 +1,255 @@ +# Copyright (c) 2022-2025, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py. +# Separate to make sure all the functionalities are working as expected. +# Otherwise reference implementation will get messy. + +# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality +# together with the quantization functionality. + +from typing import Tuple +import math + +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + +import pytest +import torch + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + contiguous: bool, + return_transpose: bool, + use_cpp_allocator: bool, + swizzled_scale: bool = False, + hadamard_dimension: int = 16, + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, +) -> None: + assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled." 
+ + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + x = x.transpose(0, 1) if not contiguous else x + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + amax_rowwise = x_nvfp4_sut._amax_rowwise + amax_colwise = x_nvfp4_sut._amax_columnwise + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + # Reference quantization using NVFP4QuantizerRef with built-in RHT + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=with_rht, + with_random_sign_mask=with_random_sign_mask, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is 
not None else None + ref_amax_rowwise = x_nvfp4_ref.global_amax_row + + if return_transpose: + assert x_nvfp4_ref.data_t is not None + assert x_nvfp4_ref.scale_t is not None + qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8) + # Compute transpose amax using the same reference quantizer + x_t_for_amax = ( + ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous() + ) + ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1) + else: + qx_t_ref = None + sx_t_ref = None + ref_amax_colwise_t = None + + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # Real shapes, + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ], +) 
+@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_rht_with_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=True, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +): + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=False, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py new file mode 100755 index 0000000000..46077eb205 --- /dev/null +++ 
b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -0,0 +1,238 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + +seed = 12345 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +_FP4_LUT = torch.tensor( + [ + 0.0, # 0: 0000 - zero + 0.5, # 1: 0001 - smallest positive normal + 1.0, # 2: 0010 + 1.5, # 3: 0011 + 2.0, # 4: 0100 + 3.0, # 5: 0101 + 4.0, # 6: 0110 + 6.0, # 7: 0111 - largest positive normal + -0.0, # 8: 1000 - negative zero + -0.5, # 9: 1001 - smallest negative normal + -1.0, # 10: 1010 + -1.5, # 11: 1011 + -2.0, # 12: 1100 + -3.0, # 13: 1101 + -4.0, # 14: 1110 + -6.0, # 15: 1111 - largest negative normal + ], + dtype=torch.float32, +) + + +def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: + # Convert FP4 indices to their corresponding floating point values + # Each index (0-15) represents a 4-bit FP4 value in E2M1 format + # Values based on the FP4 E2M1 specification + fp4_lut = _FP4_LUT.to(fp4.device) + return fp4_lut[fp4.to(torch.long)] + + +def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: + sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) + dqx = fp4_to_fp32(unpack_fp4(qx)) + sf = sf[: dqx.shape[0], : dqx.shape[1]] + dequant = dqx * sf * (amax / (6.0 * 448)) + return dequant + + +def RHT(x: torch.Tensor) -> torch.Tensor: + def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [ + 1.0, + 1.0, + 1.0, + -1.0, 
def RHT(x: torch.Tensor) -> torch.Tensor:
    """Reference blockwise randomized Hadamard transform (block size 16).

    Computes ``x_block @ (S @ H) / sqrt(16)`` over each contiguous group
    of 16 elements along the last dimension, where ``H`` is a 16x16
    Sylvester Hadamard matrix and ``S`` is the fixed wgrad sign diagonal.
    """

    def get_wgrad_sign_vector() -> torch.Tensor:
        """Hard-coded signs for Hadamard transform"""
        return torch.tensor(
            [
                1.0,
                1.0,
                1.0,
                -1.0,
                1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                1.0,
                -1.0,
                1.0,
                -1.0,
                -1.0,
            ],
            dtype=torch.float32,
        )

    def _build_hadamard_matrix(
        size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.

        Uses Sylvester construction to avoid SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)

    rht_dim = 16
    # Build H and scale
    H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
    scale = 1.0 / float(rht_dim) ** 0.5

    # Perform blockwise transform along the last dimension
    original_shape = x.shape
    x_mat = x.contiguous().view(-1, rht_dim)
    # NOTE: _build_hadamard_matrix is called with its default
    # with_random_sign_mask=True, so H already carries the fixed wgrad
    # sign flips (the sign matrix is NOT the identity here).
    transform = H * scale
    out = x_mat @ transform
    return out.view(original_shape)


def quantize_fp4(
    x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to NVFP4 and return the raw quantized buffers.

    Returns ``(qx, sx, qx_t, sx_t)``: rowwise packed data, rowwise scale
    factors, columnwise packed data, and columnwise scale factors.
    """
    nvfp4_quantizer = NVFP4Quantizer(
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=use_RHT,
        with_post_rht_amax=True,
        stochastic_rounding=use_stochastic_rounding,
        with_2d_quantization=use_2D,
    )

    x_nvfp4_sut = nvfp4_quantizer(x)
    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    assert x_nvfp4_sut._columnwise_data is not None
    qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._columnwise_scale_inv is not None
    sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv

    return qx, sx, qx_t, sx_t
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
    """Check that NVFP4 stochastic rounding beats round-to-nearest on average.

    Quantizes a random (M, N) tensor once with round-to-nearest (RN) and
    ``n_iters`` times with stochastic rounding (SR); the averaged SR result
    must have lower RMSE than the RN result, for both the rowwise data and
    the columnwise (transposed) data.
    """
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50

    # NOTE(review): randn already has zero mean; the `* 2 - 1` shift reads
    # like it was written for rand() — confirm the intended distribution.
    x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    y = x.t().contiguous()
    if use_RHT:
        # The columnwise reference must see the same transform the
        # quantizer applies before quantizing.
        y = RHT(y)
    amax = torch.max(torch.abs(x)).float()

    # Round-to-nearest baseline RMSE (rowwise and columnwise).
    q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
        x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
    )
    dq_rn = dequantize_fp4(q_rn, s_rn, amax)
    dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
    error_rn = (dq_rn - x).float()
    me_rn = torch.sqrt((error_rn * error_rn).mean())
    error_t_rn = (dq_t_rn - y).float()
    me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())

    # Accumulate dequantized stochastic-rounding samples.
    sr_result = torch.zeros_like(x).float()
    sr_t_result = torch.zeros_like(x).float().t().contiguous()
    for _ in range(n_iters):
        q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
            x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
        )

        dq_sr = dequantize_fp4(q_sr, s_sr, amax)
        dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)

        sr_result += dq_sr.float()
        sr_t_result += dq_t_sr.float()

    # Get the mean result of the stochastic rounding
    # It should be more accurate than the RN result
    sr_result /= n_iters
    error_sr = (sr_result - x).float()
    me_sr = torch.sqrt((error_sr * error_sr).mean())
    sr_t_result /= n_iters
    error_t_sr = (sr_t_result - y).float()
    me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())

    print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
    print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
    assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
    assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest."


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (8192, 8256),  # to test the nonfused RHT path
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    M: int,
    N: int,
) -> None:
    """Parametrized entry point for the NVFP4 stochastic-rounding check."""
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
    )
def check_rht_usage(recipe: "recipe.Recipe") -> bool:
    """Return True if an NVFP4 recipe enables the random Hadamard transform.

    Checks fp4_quant_fwd_inp, fp4_quant_fwd_weight and fp4_quant_bwd_grad;
    non-NVFP4 recipes always return False. The annotation is quoted because
    the parameter name shadows the ``recipe`` module.
    """
    # if using RHT, we can only support bf16
    if recipe.nvfp4():
        if (
            recipe.fp4_quant_fwd_inp.random_hadamard_transform
            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
        ):
            return True
    return False


def get_nvfp4_inp_supported_dtypes(recipe: "recipe.Recipe", dtype: torch.dtype) -> list:
    """Return the list of input dtypes supported by an NVFP4 recipe.

    bf16 is always supported; fp32 is additionally supported when the
    recipe does not use the random Hadamard transform. Non-NVFP4 recipes
    yield an empty list. ``dtype`` is currently unused but kept for
    interface stability. (Return annotation fixed: the function returns a
    list, not a bool.)
    """
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
        # if not using RHT, we can add fp32 as well
        if not check_rht_usage(recipe):
            supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes
and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op": + pytest.skip( + f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" + ) + if fp8 and fp8_recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {fp8_recipe.__class__.__name__}" + ) + if fp8_params: + pytest.skip("NVFP4 params not supported") # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -334,17 +391,19 @@ def test_make_graphed_callables( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, ) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, + dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, - dtype=torch.float32, + dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index a0d6f1fd94..82bd61a01e 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -10,7 +10,6 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex -import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8CurrentScaling from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype @@ -273,6 +272,14 @@ def run_linear_multiple_steps( if bgrad_list is not None and bgrad is not None: bgrad_list.append(bgrad.detach().clone()) + # Stack the 
results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + @classmethod def run_linear( cls, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index bb07e87d98..4409866617 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,15 +35,17 @@ Float8Quantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Import utility functions -from utils import dtype_tols, make_recipe, reset_rng_states +from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states -# Check if FP8 is supported +# Check for supported quantization schemes fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -59,6 +61,8 @@ _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: _quantization_list.append("mxfp8") +if nvfp4_available: + _quantization_list.append("nvfp4") def maybe_skip_quantization( @@ -66,6 +70,7 @@ def maybe_skip_quantization( *, dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, ) -> None: """Skip test case if a quantization scheme is not supported""" @@ -73,12 +78,17 @@ def maybe_skip_quantization( if quantization is None: return - # Check if quantization scheme is supported + # Check if quantization scheme is supported on device + if device is not None and torch.device(device).type != 
"cuda": + pytest.skip("Quantization is only supported on CUDA devices") if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + # Check dims if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) @@ -88,10 +98,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + elif quantization == "nvfp4": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") - # Check if device is supported - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") + # Check dtype + if dtype is not None: + if quantization == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 quantization is only supported with BF16 data") @torch.no_grad() @@ -141,6 +155,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -395,12 +417,12 @@ def test_fp8_scale_update( torch.testing.assert_close( y, torch.full_like(y, y_val_ref), - **dtype_tols(tex.DType.kFloat8E4M3), + **quantization_tols("fp8_delayed_scaling"), ) torch.testing.assert_close( x.grad, torch.full_like(x.grad, dx_val_ref), - **dtype_tols(tex.DType.kFloat8E5M2), + 
**quantization_tols("fp8_delayed_scaling"), ) # Check that scaling factors match expected @@ -434,7 +456,8 @@ def test_dtype_cast( # Skip invalid configurations in_shape = (size, size) with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype) + maybe_skip_quantization(quantization, dtype=final_dtype) # Random data dtype = torch.float32 @@ -502,7 +525,8 @@ def test_pyt_autocast( # Skip invalid configurations in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype) + maybe_skip_quantization(quantization, dtype=autocast_dtype) # Construct operation recipe = make_recipe(quantization) @@ -558,7 +582,7 @@ def test_identity( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -624,7 +648,7 @@ def test_reshape( # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) with_quantization = quantization is not None # Random data @@ -690,7 +714,7 @@ def test_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -752,7 +776,7 @@ def test_quantize( # Skip invalid 
configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8": maybe_skip_quantization(quantization, dims=in_shape) @@ -819,7 +843,7 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) quantization_needed = any( ( @@ -899,7 +923,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute or quantized_output or quantized_grad_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1010,7 +1034,7 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") @@ -1077,7 +1101,7 @@ def test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1114,7 +1138,7 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = 
make_reference_and_test_tensors( @@ -1175,7 +1199,7 @@ def test_layer_norm( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1284,7 +1308,7 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1337,7 +1361,7 @@ def test_rmsnorm( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1417,7 +1441,7 @@ def test_add_extra_input( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x1_ref, x1_test = make_reference_and_test_tensors( @@ -1456,8 +1480,11 @@ def test_add_extra_input( # Check results tols = dtype_tols(dtype) - if with_quantization: - tols = dtype_tols(x1_test._fp8_dtype) + if in_place: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): + tols = dtype_tols(x1_test._fp8_dtype) + elif quantization == "nvfp4": + tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") @@ -1486,7 +1513,7 @@ def test_make_extra_output( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + 
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1559,7 +1586,7 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if cache_quantized_input: maybe_skip_quantization("fp8_current_scaling", device=device) @@ -1633,8 +1660,10 @@ def test_activation( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute or cache_quantized_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + if quantized_compute: + tols = quantization_tols(quantization) + elif cache_quantized_input: + tols = quantization_tols("fp8_current_scaling") # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1665,7 +1694,7 @@ def test_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1699,7 +1728,7 @@ def test_swiglu( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1767,7 +1796,7 @@ def test_dropout( # Skip invalid configurations quantized_input = quantization is not None - maybe_skip_quantization(quantization, dims=shape, device=device) + maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype) # Random data # Note: Shift values to make sure inputs are non-zero @@ -1858,7 +1887,7 @@ def test_forward_linear_bias_activation( # Skip 
invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( @@ -1929,7 +1958,7 @@ def test_forward_linear_bias_activation( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1965,7 +1994,7 @@ def test_forward_linear_bias_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2040,7 +2069,7 @@ def test_forward_linear_bias_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2078,7 +2107,7 @@ def test_forward_linear_scale_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2146,7 +2175,7 @@ def test_forward_linear_scale_add( if dtype == 
torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2179,7 +2208,7 @@ def test_backward_activation_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0): pytest.skip("Unsupported tensor size for MXFP8") @@ -2241,7 +2270,7 @@ def test_backward_activation_bias( # Expected numerical error tols = dtype_tols(dtype) if with_quantization: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2360,7 +2389,7 @@ def test_backward_linear_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2428,7 +2457,7 @@ def test_backward_linear_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") @@ -2463,7 +2492,7 @@ def test_backward_linear_scale( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) 
maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2523,7 +2552,7 @@ def test_backward_linear_scale( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2564,7 +2593,7 @@ def test_linear( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) # Construct model @@ -2690,7 +2719,7 @@ def test_layernorm_mlp( ffn_shape = in_shape[:-1] + (ffn_hidden_size,) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=ffn_shape, device=device) quantization_needed = quantized_compute or quantized_weight if quantization is None and quantization_needed: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9a51c53e35..004abfd977 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -19,6 +19,7 @@ fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear from transformer_engine.pytorch.distributed import fp8_autocast @@ -499,3 +500,39 @@ def test_quantizer_update(self, module_class): y = module(x, [batch_size]) else: y = module(x) + + +fp4_available, 
@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (304, 304),
        (320, 256),
        # largest tile
        (8192, 8192),
    ],
)
def test_fp4_dequantize(dtype, M, N):
    """Quantize -> dequantize -> requantize must be a fixed point.

    Requantizing a dequantized NVFP4 tensor must reproduce the rowwise
    packed data bit-exactly, and its dequantization must match the first
    dequantization.
    """
    q = NVFP4Quantizer()
    a = torch.rand((M, N)).cuda().to(dtype=dtype)
    starting_tensor = q(a)
    dequantized_tensor = starting_tensor.dequantize()
    new_tensor = q(dequantized_tensor)
    # Bit-exact comparison: already-representable values must not move.
    torch.testing.assert_close(
        new_tensor._rowwise_data,
        starting_tensor._rowwise_data,
        rtol=0,
        atol=0,
    )
    new_dequantized_tensor = new_tensor.dequantize()
    torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
init_method_normal(sigma) @@ -407,6 +419,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) @@ -437,6 +451,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -476,6 +492,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4(): + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -526,6 +544,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -568,6 +588,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -629,6 +651,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): 
pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -683,6 +707,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -734,6 +760,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -764,6 +792,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -798,6 +828,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -832,6 +864,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = 
init_method_normal(sigma) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 9e90f9fdad..d77256b7f9 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -73,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: # Transformer Engine dtypes if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat4E2M1: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, @@ -95,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e4m3fn: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 if dtype == torch.float8_e5m2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") +def quantization_tols(name: str) -> dict[str, float]: + """Estimated numerical error for a quantization scheme""" + if name in ( + "fp8", + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "mxfp8_block_scaling", + ): + return dtype_tols(tex.DType.kFloat8E4M3) + if name == "nvfp4": + return dtype_tols(tex.DType.kFloat4E2M1) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def make_recipe(name: Optional[str]) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: @@ -118,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() + if name == "nvfp4": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + ) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 08e876404c..a4915080e8 100644 --- a/transformer_engine/common/CMakeLists.txt +++ 
b/transformer_engine/common/CMakeLists.txt @@ -53,6 +53,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +# NVIDIA MathDX include directory (from Python package install location) +if(NOT DEFINED MATHDX_INCLUDE_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx + OUTPUT_VARIABLE _PIP_SHOW_MATHDX + ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR + RESULT_VARIABLE _PIP_SHOW_MATHDX_RES + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) + message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") + endif() + string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") + if(NOT _MATHDX_LOC_MATCH) + message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") + endif() + set(MATHDX_LOCATION "${CMAKE_MATCH_1}") + set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") +endif() +if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") + message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") +endif() + # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) 
set(transformer_engine_SOURCES) @@ -73,6 +95,7 @@ list(APPEND transformer_engine_SOURCES transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu activation/gelu.cu dropout/dropout.cu fused_attn/flash_attn.cu @@ -85,6 +108,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu + gemm/config.cpp gemm/cublaslt_gemm.cu gemm/cutlass_grouped_gemm.cu normalization/common.cpp @@ -113,6 +137,9 @@ list(APPEND transformer_engine_SOURCES recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu + recipe/nvfp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu @@ -144,7 +171,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 8b7f92aff9..666f57188d 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { return CUDA_R_8F_E4M3; case DType::kFloat8E5M2: return CUDA_R_8F_E5M2; +#if CUDA_VERSION >= 12080 + case DType::kFloat4E2M1: + return CUDA_R_4F_E2M1; +#endif default: NVTE_ERROR("Invalid type"); } @@ -160,7 +164,9 @@ CUtensorMapDataType 
get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits) { + const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { @@ -169,6 +175,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; + + // Dimension for the packed data types must reflect the number of individual U# values. uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next @@ -207,7 +215,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, // Swizzling can be used to avoid shared memory bank conflicts. - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + swizzle, // L2 Promotion can be used to widen the effect of a cache-policy to a wider // set of L2 cache lines. 
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index e2a3c52aa2..bddd9bf194 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -48,8 +48,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } +inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + +inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -108,6 +114,7 @@ struct Tensor { SimpleTensor data; SimpleTensor columnwise_data; SimpleTensor amax; + SimpleTensor columnwise_amax; SimpleTensor scale; SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; @@ -119,6 +126,7 @@ struct Tensor { : data(), columnwise_data(), amax(nullptr, {1}, DType::kFloat32), + columnwise_amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), @@ -129,6 +137,7 @@ struct Tensor { data.clear(); columnwise_data.clear(); amax.clear(); + columnwise_amax.clear(); scale.clear(); scale_inv.clear(); columnwise_scale_inv.clear(); @@ -174,6 +183,7 @@ struct Tensor { * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). 
*/ switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: case NVTE_DELAYED_TENSOR_SCALING: if (!has_data() && has_columnwise_data()) { std::vector ret; @@ -189,7 +199,6 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -261,12 +270,18 @@ struct QuantizationConfig { NVTETensor noop_tensor = nullptr; Float8BlockScaleTensorFormat float8_block_scale_tensor_format = Float8BlockScaleTensorFormat::GEMM_READY; + NVTETensor rng_state = nullptr; + bool nvfp4_2d_quantization = false; + bool stochastic_rounding = false; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon - sizeof(NVTETensor), // noop_tensor - sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor), // noop_tensor + sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format + sizeof(NVTETensor), // rng_seed and offset + sizeof(bool), // nvfp4_2d_quantization + sizeof(bool) // stochastic_rounding }; }; @@ -298,6 +313,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif using e8m0_t = uint8_t; @@ -334,17 +351,20 @@ struct TypeExtrema; template <> struct TypeExtrema { static constexpr float max = 6.0f; + static constexpr float max_inverse = 1.0 / max; }; #endif template <> struct TypeExtrema { static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> struct TypeExtrema { static constexpr float max = 57344.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> @@ -558,6 +578,18 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +// Add a pack_size argument to select the packed type for FP4 +#define 
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat4E2M1: { \ + using type = __nv_fp4x2_storage_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -717,10 +749,11 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. -void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); +void create_2D_tensor_map( + CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, + const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, + const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp new file mode 100644 index 0000000000..cf211beaf9 --- /dev/null +++ b/transformer_engine/common/gemm/config.cpp @@ -0,0 +1,116 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "./config.h" + +#include +#include + +#include + +#include "../util/logging.h" + +NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; } + +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(buf, &config_.bias_tensor, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(buf, &config_.dbias_tensor, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(buf, &config_.with_gelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(buf, &config_.use_split_accumulator, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; + default: + 
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(&config_.bias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(&config_.dbias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(&config_.with_gelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(&config_.use_split_accumulator, buf, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_matmul_config(NVTEMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h new file mode 100644 index 
0000000000..54ccf06a53 --- /dev/null +++ b/transformer_engine/common/gemm/config.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ +#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ + +#include + +namespace transformer_engine { + +struct MatmulConfig { + NVTETensor bias_tensor = nullptr; + NVTETensor dbias_tensor = nullptr; + bool with_gelu_epilogue = false; + bool with_dgelu_epilogue = false; + NVTETensor epilogue_aux_tensor = nullptr; + bool use_split_accumulator = false; + int sm_count = 0; + + static constexpr size_t attr_sizes[] = { + sizeof(NVTETensor), // bias_tensor + sizeof(NVTETensor), // dbias_tensor + sizeof(bool), // with_gelu_epilogue + sizeof(bool), // with_dgelu_epilogue + sizeof(NVTETensor), // epilogue_aux_tensor + sizeof(bool), // use_split_accumulator + sizeof(int) // sm_count + }; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f287072bcb..ab80fe7698 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -9,20 +9,55 @@ #include #include #include +#include #include +#include #include #include +#include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/handle_manager.h" #include "../util/logging.h" #include "../util/multi_stream.h" -#include "common/util/cuda_runtime.h" -#include "cutlass_grouped_gemm.cuh" +#include "./config.h" +#include "./cutlass_grouped_gemm.cuh" namespace { +/* Use CUDA const memory to store scalar 1 and 0 for cublas usage +*/ +__device__ __constant__ float one_device; +__device__ __constant__ float 
zero_device; + +inline float *GetScalarOne() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float one = 1.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), one_device)); + return dev_ptr; +} + +inline float *GetScalarZero() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float zero = 0.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), zero_device)); + return dev_ptr; +} + +__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; } + uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -82,6 +117,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool is_A_transposed = transA == CUBLAS_OP_T; bool is_B_transposed = transB == CUBLAS_OP_T; + // Set conditions for MXFP8 and NVFP4 gemm execution. + const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); + const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { // Unscaled or FP8 tensor scaling @@ -102,10 +141,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - } else if (is_mxfp_scaling(A.scaling_mode)) { - // MXFP8 + } else if (nvfp4) { + // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. 
+ + if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode), + "Input A has unsupported combination of recipe and layout"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout. + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + } else if (mxfp8) { + // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -161,10 +216,20 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } } - } else if (is_mxfp_scaling(B.scaling_mode)) { - // MXFP8 - // Note: Row-wise and column-wise data are scaled along different - // dimensions (with matrix interpreted in row-major order). + } else if (nvfp4) { + if (is_B_transposed) { + NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), + "Input B has unsupported combination of recipe and layout"); + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout. + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -221,7 +286,7 @@ using cublasHandleManager = detail::HandleManageramax.dptr != nullptr || inputB->amax.dptr != nullptr)) { + // Reserve some workspace for alpha scale + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + uint8_t *workspace_ptr = reinterpret_cast(workspace); + float *new_alpha_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + + // Update alpha scale on device + // Note: Compute NVFP4 tensor scales based on amaxes and then + // divide from alpha scale. This way we only need to apply NVFP4 + // tensor scales in matmul output, instead of in matmul inputs. + float old_alpha = *reinterpret_cast(alpha); // Assumed to be on CPU + TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb, + old_alpha, new_alpha_tensor.data(), stream); + alpha = new_alpha_ptr; + + // Make sure beta scale is on device + float old_beta = *reinterpret_cast(beta); // Assumed to be on CPU + if (old_beta == 0) { + beta = GetScalarZero(); // Device constant memory + } else if (old_beta == 1) { + beta = GetScalarOne(); // Device constant memory + } else { + // Move beta to workspace + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + float *new_beta_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta); + NVTE_CHECK_CUDA(cudaGetLastError()); + beta 
= new_beta_ptr; + } + } const cudaDataType_t A_type = get_cuda_dtype(param.Atype); const cudaDataType_t B_type = get_cuda_dtype(param.Btype); @@ -270,16 +378,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); // check consistency of arguments: // if fp8 is desired, context cannot be null // fp8 + gelu fusion + fp8 aux is unavailable right now. - if (use_fp8 && gelu) { + if ((use_fp8 || use_fp4) && gelu) { NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), "fp8 Aux output for gemm + gelu fusion not supported!"); } - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); + if (is_fp4_dtype(outputD->data.dtype)) { + NVTE_ERROR("FP4 GEMM output is not supported!"); + } + if (use_fp4 && (D_type == CUDA_R_16F)) { + NVTE_ERROR("FP4 GEMM does not support FP16 output!"); } cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -319,12 +434,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &math_sm_count, sizeof(math_sm_count))); } - // set fp8 attributes -- input and output types should already be set to fp8 as appropriate - // Note: gelu fusion isn't available right now, and we don't need + // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4 + // as appropriate. Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). - if (use_fp8) { - // Split accumulator. - const int8_t fastAccuMode = (use_split_accumulator) ? 
0 : 1; + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + if (use_fp8 || use_fp4) { + // Fast accumulation is only supported for FP8. + const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); @@ -333,7 +450,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_b; #endif // CUBLAS_VERSION >= 120800 - if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { + if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -346,7 +463,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; #endif // CUBLAS_VERSION >= 120800 - } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { + } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); @@ -371,6 +488,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #else NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120800 + } else if (use_fp4) { // NVFP4 GEMM +#if CUBLAS_VERSION >= 120800 + NVTE_CHECK(cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE + cublasDataType_t 
scale_type = CUDA_R_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // Set pointer mode: alpha and beta are both device pointers + // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); + + fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif // CUBLAS_VERSION >= 120800 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && @@ -503,14 +648,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif -#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ - CUBLAS_VERSION < 130000 +#else NVTE_CHECK(cuda::cudart_version() >= 12020 && 
cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); @@ -565,16 +707,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ - param.A, /* A */ - Adesc, param.B, /* B */ - Bdesc, static_cast(&beta), /* beta */ - C, /* C */ - Cdesc, D, /* D */ - Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ - workspaceSize, stream)); /* stream */ + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ + param.A, /* A */ + Adesc, param.B, /* B */ + Bdesc, beta, /* beta */ + C, /* C */ + Cdesc, D, /* D */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. @@ -600,35 +741,117 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. 
Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, - nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); +} + +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_v2); + using namespace transformer_engine; + + // Data tensors + const Tensor *A_tensor = convertNVTETensorCheck(A); + const Tensor *B_tensor = convertNVTETensorCheck(B); + const Tensor *C_tensor = convertNVTETensorCheck(C); + Tensor *D_tensor = convertNVTETensorCheck(D); + NVTE_CHECK(C_tensor == D_tensor, + "Currently nvte_cublas_gemm_v2 does not support different C and D tensors."); + + // Workspace + void *workspace_ptr = nullptr; + size_t workspace_size = 0; + Tensor *workspace_tensor = convertNVTETensor(workspace); + if (workspace_tensor != nullptr) { + workspace_ptr = workspace_tensor->data.dptr; + workspace_size = + get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype); + } + + // Additional config + MatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + + // Configure GEMM epilogue + const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue); + if (with_grad_epilogue) { + NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue, + "Invalid epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ")."); + } + Tensor 
dummy_tensor; + Tensor *epilogue_bias_tensor = &dummy_tensor; + if (!with_grad_epilogue && config_.bias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor); + } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor); + } + Tensor *epilogue_aux_tensor = &dummy_tensor; + if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) { + NVTE_CHECK(config_.epilogue_aux_tensor != nullptr, + "Requested epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor."); + epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor); + } + + // Launch GEMM + cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, + transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N, + with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta, + config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { - NVTE_API_CALL(nvte_cublas_gemm_scaled); + NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if 
(is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -639,17 +862,14 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); using namespace transformer_engine; - - // Check CUDA and cuBLAS versions #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif +#else NVTE_CHECK( transformer_engine::cuda::cudart_version() >= 12020 && transformer_engine::cuda::cudart_version() < 13000, @@ -668,13 +888,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = convertNVTETensor(counter); Tensor *wspace = convertNVTETensor(workspace); + const void *alpha_ptr = GetScalarOne(); + const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero(); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? 
CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, - n_split, gemm_producer, inputCounter, stream); + alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, inputCounter, stream); +#endif } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -695,9 +919,30 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens } for (int i = 0; i < num_gemms; i++) { - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); + // Check whether GELU or dGELU epilogue is requested + Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); + bool with_gelu_dgelu_epilogue = + (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr); + + // Construct config + MatmulConfig config; + if (grad) { + config.dbias_tensor = bias[i]; + config.with_dgelu_epilogue = with_gelu_dgelu_epilogue; + } else { + config.bias_tensor = bias[i]; + config.with_gelu_epilogue = with_gelu_dgelu_epilogue; + } + config.epilogue_aux_tensor = pre_gelu_out[i]; + config.use_split_accumulator = use_split_accumulator; + config.sm_count = math_sm_count; + + // Launch GEMM + const float alpha = 1.f; + const float beta = accumulate ? 
1.f : 0.f; + nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], + workspace[i % num_streams], &config, + detail::get_compute_stream(i % num_streams)); } // record events on compute streams diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu new file mode 100644 index 0000000000..9d4bec41d5 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -0,0 +1,876 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kThreadsPerWarp = 32; +constexpr float k16x16HadamardScale = 0.25f; + +template +__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr) { + auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); + if constexpr (kTranspose) { + asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } else { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } +} + +template +__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr, uint32_t stride) { + if constexpr (kTranspose) { + asm volatile( + "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : 
"l"(addr), "r"(stride)); + } else { + asm volatile( + "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } +} + +template +__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, void* addr, + uint32_t stride) { + if constexpr (kTranspose) { + asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } else { + asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { + asm volatile( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) + : "r"(a0)); +} + +__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { + __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); + float f_a = __bfloat162float(bf16x2.x); + float f_b = __bfloat162float(bf16x2.y); + asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); + float_dst = fabsf(float_dst); +} + +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( + uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, + uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, + uint32_t& amax_result) { + uint32_t zero = 0; + uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + asm volatile( + "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" + "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" + "{%8, %9, %10, %11}, \n" + "{%12, %13, %14, %15}, \n" + "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), 
"=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), + "=r"(temp7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), + "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); + if constexpr (kCalculateAmax) { + uint32_t max_even; + uint32_t max_odd; + // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); + // N.B. mma is only called up to once per thread for identity and transpose respectively, so + // we don't have to accumulate into amax_result and can directly store into it. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(amax_result) + : "r"(max_even), "r"(max_odd)); + } +} + +template +__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, + uint16_t random_sign_mask, + uint32_t* had_frag_t, + uint16_t random_sign_mask_t) { + int32_t tid = threadIdx.x % 32; // Local tid + float temp_i[2]; + float temp_t[2]; +#pragma unroll + for (int i = 0; i < 2; i++) { + // i is the vertical fragment index. + // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. + uint32_t r = i * 8 + tid / 4; + +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int k = 0; k < 2; k++) { + // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. + // j is the column fragment idx selecting between even and odd fragments. + // j increments 8 columns by switching fragments. 
+ uint32_t c = j * 8 + k + tid % 4 * 2; + // 1 -> -1.0f, 0 -> 1.0f + int32_t base_sign = __popc(r & c); + if constexpr (kReturnIdentity) { + int32_t sign_i; + // Because tensor cores want the dot product dimension, + // contiguous, the regular, non-inverse hadamard swaps + // signs of columns and rows for inverse. In a simple reference, + // x.reshape(-1, 16) @ sign @ H16, this would be opposite but + // (sign @ H16) is transposed in this fragment. + if constexpr (kInverseHadamardIdentity) { + sign_i = ((random_sign_mask >> r) ^ base_sign); + } else { + sign_i = ((random_sign_mask >> c) ^ base_sign); + } + temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); + } + if constexpr (kReturnTransposed) { + int32_t sign_t; + if constexpr (kInverseHadamardTransposed) { + sign_t = ((random_sign_mask_t >> r) ^ base_sign); + } else { + sign_t = ((random_sign_mask_t >> c) ^ base_sign); + } + temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); + } + } + + if constexpr (kReturnIdentity) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_i[i * 2 + j]) + : "f"(temp_i[1]), "f"(temp_i[0])); + } + if constexpr (kReturnTransposed) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_t[i * 2 + j]) + : "f"(temp_t[1]), "f"(temp_t[0])); + } + } + } +} + +__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, + uint32_t gmem_col_idx) { + uint32_t smem_row_idx = gmem_row_idx; + uint32_t xor_factor = (smem_row_idx * 2) % 8; + uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; + return smem_row_idx * 8 + smem_col_idx; +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = 
(threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. + if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : 
"r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? 
staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr) { + if (output_pre_rht_amax_ptr != nullptr) { + *output_pre_rht_amax_ptr = 0; + } + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = 0; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = 0; + } +} + +template +__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int 
input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? 
stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. 
+ __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output, + T* __restrict__ output_t, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, uint64_t num_input_rows, + uint64_t num_input_cols, float* __restrict__ amax, + float* __restrict__ amax_t, bool inverse_hadamard) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported."); + + // The whole threadblock will share the same smem. + extern __shared__ __align__(16) T smem[]; + + // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16. + // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices. 
+ int32_t tid = threadIdx.x; + int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z; + int32_t local_bx = threadIdx.y; + int32_t local_by = threadIdx.z; + + // Define the register fragments + uint32_t a_frag[4]; // A matrix fragment + uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major) + uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major) + uint32_t c_frag[4]; // Result fragment + + // row and col for each thread. 32 threads will work together in 128 chunk to + // load the data from global memory to shared memory. + uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4)); + uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4)); + + uint32_t smem_index = tid; + + uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension; + uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension; + + bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows); + if (!load) { + // Out of bound, we are returning early. No thread divergence since the whole warp + // will return early. + return; + } + + uint64_t global_offset = input_start_col + input_start_row * num_input_cols; + uint64_t global_offset_t = + kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset; + + T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id; + + uint32_t* smem_b32 = reinterpret_cast(base_smem); + uint4* smem_b128 = reinterpret_cast(base_smem); + + // Asynchronously load the data from global memory to shared memory. + const uint4* input_b128 = reinterpret_cast(input + global_offset); + // Each 16x16 chunk is divided into 4 8x8 matrices, we are trying to load each + // 8x8 chunks consecutively into the smem, so we could leverage ldmatrix m8n8x4 + // to load the data in the tensor core swizzled format. 
+ __pipeline_memcpy_async(&smem_b128[smem_index], + &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col], + sizeof(uint4)); + __pipeline_commit(); // Commit the memcpy. Wait when we are in the computation. + + if (inverse_hadamard) { + get_hadamard_matrix_fragment(b_frag_i, random_sign_mask, + b_frag_t, random_sign_mask_t); + } else { + get_hadamard_matrix_fragment( + b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t); + } + + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + __pipeline_wait_prior(0); + + __syncwarp(); // ensure all lanes finished their cp.async before reading smem + + // Load the A to a_frag. + if constexpr (kComputeIdentity) { + load_matrix_16x16_from_shared(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32, + kHadamardDimension); + + // 16x16 @ 16x16 leveraging all threads in the warp. + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg); + + // Store the result to the shared memory in non-transposed order. 
+ if constexpr (kReturnIdentity) { + uint4* output_b128 = reinterpret_cast(output + global_offset); + store_matrix_16x16_to_global(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128, + num_input_cols); + } + } + + if constexpr (kComputeTransposed) { + if (kComputeIdentity) { + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + } else { + load_matrix_16x16_from_shared(a_frag[0], + a_frag[2], // NOTE: intentional index swapping + a_frag[1], // NOTE: intentional index swapping + a_frag[3], smem_b32, kHadamardDimension); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], + // 2,1 is used if we are using movmatrix instruction. + // Thus loading the matrix in 2,1 order will just be normal. + // This is to be compatible with the movmatrix instruction. + a_frag[2], // NOTE: intentional index swapping for transpose purpose. + a_frag[1], // NOTE: intentional index swapping for transpose purpose. + a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1], + c_frag[2], c_frag[3], local_amax_t_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnTransposed) { + uint4* output_t_b128 = reinterpret_cast(output_t + global_offset_t); + store_matrix_16x16_to_global( + c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128, + kOutputTrueTransposed ? num_input_rows : num_input_cols); + } + } + + if constexpr (kUpdateIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + local_amax = warp_reduce_max(local_amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero); + // atomic CAS to output memory. 
+ if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax, local_amax); + } + } + if constexpr (kUpdateTransposeAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + local_amax_t = warp_reduce_max(local_amax_t); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax_t, local_amax_t); + } + } +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+."); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +} + +} // namespace + +void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform); + + // Check tensors + // NOTE (frsun): This is non-intuitive, we are writing the result of + // transposed RHT to the output of rowwise. + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output_.scaling_mode), "."); + const SimpleTensor& input = input_.data; + SimpleTensor output; + SimpleTensor& output_t = output_.data; + + // Check requested outputs + const bool return_identity = output.dptr != nullptr; + const bool return_transposed = output_t.dptr != nullptr; + if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior. 
+ return; + } + + checkCuDriverContext(stream); + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + using IType = bf16; + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kThreadBlockX = 4; + // Configure 4 is used for Hopper, 8 is used for Blackwell for extra memory bandwidth. + constexpr uint64_t kThreadBlockY = 4; + + uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY; + + // The shared memory number of bytes required for **the whole threadblock**. + size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM; + + dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY); + + dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX), + DIVUP(num_rows / kHadamardDimension, kThreadBlockY)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed, kReturnTransposed, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + auto kernel = + HadamardTransformKernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes); + + kernel<<>>( + reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, + num_rows, row_length, nullptr, nullptr, false););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then +// get the absolute max value of the result. 
+void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_amax); +#if CUDA_VERSION >= 12080 + + // Check input tensor + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor& input = input_.data; + + // Check amax tensors + SimpleTensor& output_pre_rht_amax = output_.amax; + SimpleTensor output_identity_amax; + SimpleTensor& output_transpose_amax = output_.columnwise_amax; + + // Check requested outputs + const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr; + const bool return_identity_amax = output_identity_amax.dptr != nullptr; + const bool return_transposed_amax = output_transpose_amax.dptr != nullptr; + if (!return_identity_amax && !return_transposed_amax && + !return_pre_rht_amax) { // Nothing to do/ill-defined behavior. 
+ return; + } + + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kChunkBlockXSmall = 128; + constexpr uint64_t kChunkBlockYSmall = 128; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input, + /*globalY=*/num_rows, + /*globalX=*/row_length, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/row_length, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + + dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed_amax, kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity_amax, kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = 
sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = HadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, + random_sign_mask_t, num_rows, row_length);))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform); + using namespace transformer_engine; + hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} + +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_amax); + using namespace transformer_engine; + hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu 
b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..ce191b5ffd --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -0,0 +1,841 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. 
+ +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread()); + + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +// calculate the global encode scale factor for a given global amax. +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float kFP8E4M3Max = 448.0f; + constexpr float kFP4E2M1Max = 6.0f; + // If scale is infinity, return max value of float32 + float global_encode_scale = cutlass::minimum_with_nan_propagation{}( + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + // If global amax is 0 or infinity, return 1 + return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; +} + +template +struct SharedStorage { + static constexpr int AccumulatorPipelineStageCount = 16; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + +}; + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + 
result_type output; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), + "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), + "r"(rbits[0]), "r"(rbits[1])); +#else + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +__global__ static +void +rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, + TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TC * C, CStride dC, CSmemLayout , + TSFC * SFC, + TiledMMA mma, + float const* global_amax, + const size_t* rng_state) +{ + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = + 
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); + Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) + + auto sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), N / 64 ) + ); + + auto sfc_stride = make_stride( + N / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + auto sfc_layout = make_layout(sfc_shape, sfc_stride); + Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout); + + auto cluster_shape = Shape< _1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + + auto 
mainloop_tiler = Shape<_128,_16,_64>{}; + auto epilogue_tiler = Shape<_128,_64,_64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + + auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, + Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, + Int{})); + + 
TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition(tma_load_a, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition(tma_load_b, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + if (is_epilogue_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = 
cutlass::make_producer_start_state(); + + + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,tile_idx_m,_); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int 
k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) + { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) + { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = 
bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + const float global_amax_val = *global_amax; + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N) + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + + const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + const float global_decode_scale = 1.0f / global_encode_scale; + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < 
K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile); + + Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( + make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) + )); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + // Cast data from FP32 to BF16 to FP32. 
+ auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + + // Initialize RNG for tile + const size_t rng_sequence + = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = dist.generate4(rng); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale + ), + reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * 
K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void +rht_gemm_ntt_w_sfc(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 2048) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + auto dC = make_stride(n, Int<1>{}); // (dM,dN) + + auto cga_shape = Shape< _1, _1, _1>{}; + auto cga_tile_shape = Shape<_128,_16,_16>{}; + auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. 
+ CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + 
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cga_tile_shape, + mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, + "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, + "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), + " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto* kernel_ptr = &rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TC, decltype(dC), decltype(sC), + TSFC, + decltype(mma), + kEnableStochasticRounding>; + + bool status = cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (status != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." 
<< std::endl; + return; + } + (*kernel_ptr) + <<< dimGrid, dimBlock, smem_size, stream >>> + (M, N, k_tile_size, cga_tile_shape, + A, dA, sA, tma_load_a, + B, dB, sB, tma_load_b, + C, dC, sC, + SFC, + mma, global_amax, + rng_state); +} + +// this function is used to wrap the rht_gemm_ntt_w_sfc function +//to transpose the input tensor A +template +void +rht_gemm_ttt_wrapper(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) +{ + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // utilize as many SMs as possible while keeping + // a relatively large contiguous dimension. + // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + rht_gemm_ntt_w_sfc( + n, m, + A, B, C, + SFC, global_amax, + rng_state, + sm_count, stream, + k_tile_size); +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + SimpleTensor &global_amax = output_.amax; + SimpleTensor &output_t = output_.data; + SimpleTensor &scale_inv_t = 
output_.scale_inv; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by 
hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size);); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion_columnwise( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git 
a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0c358328b6..950014cc9b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -15,9 +15,76 @@ #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus -/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. +/*! \brief Configuration for matrix multiplication. */ +typedef void *NVTEMatmulConfig; + +/*! \enum NVTEMatmulConfigAttribute + * \brief Type of option for matrix multiplication. + */ +enum NVTEMatmulConfigAttribute { + /*! Bias tensor + * + * If provided, the bias tensor is applied in the GEMM epilogue. + */ + kNVTEMatmulConfigBiasTensor = 0, + /*! Bias gradient tensor + * + * If provided, the bias gradient tensor will be filled in the GEMM epilogue. + */ + kNVTEMatmulConfigDBiasTensor = 1, + /*! Whether to compute GELU in GEMM epilogue. */ + kNVTEMatmulConfigWithGELUEpilogue = 2, + /*! Whether to compute GELU backward in GEMM epilogue. */ + kNVTEMatmulConfigWithDGELUEpilogue = 3, + /*! Auxiliary tensor for GEMM epilogue. + * + * For GELU, this will be filled with the GELU input. For GELU + * backward, this is expected to already be filled with the GELU + * input. + */ + kNVTEMatmulConfigEpilogueAuxTensor = 4, + /*! Whether to use split accumulator for FP8 GEMM. */ + kNVTEMatmulConfigUseSplitAccumulator = 5, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEMatmulConfigSMCount = 6, + kNVTEMatmulConfigNumAttributes +}; + +/*! \brief Create a matrix multiplication configuration. */ +NVTEMatmulConfig nvte_create_matmul_config(); + +/*! \brief Query an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. 
+ * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a matrix multiplication configuration. */ +void nvte_destroy_matmul_config(NVTEMatmulConfig config); + +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. + * + * Computes: + * - `D = alpha * op(A) * op(B) + beta * C` + * + * \param[in] transa Whether to transpose A matrix. + * \param[in] transb Whether to transpose B matrix. + * \param[in] alpha Scaling factor applied to matmul output. + * \param[in] A A matrix. + * \param[in] B B matrix. + * \param[in] beta Scaling factor applied to C matrix. + * \param[in] C C matrix. + * \param[out] D Output matrix. + * \param[in] workspace Workspace tensor. + * \param[in] config Additional configuration. 
+ * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, - * allowing for using a scaling factor for the GEMM result and the accumulation input + * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated) + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -133,14 +223,16 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, - bool transa, bool transb, bool grad, NVTETensor* workspace, +void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); #ifdef __cplusplus } // extern "C" -#endif +#endif // __cplusplus + +#ifdef __cplusplus /*! \namespace transformer_engine */ @@ -153,6 +245,89 @@ namespace transformer_engine { void nvte_cublas_handle_init(); +/*! \struct MatmulConfigWrapper + * \brief C++ wrapper for NVTEMatmulConfig. 
+ */ +class MatmulConfigWrapper { + public: + MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {} + + MatmulConfigWrapper(const MatmulConfigWrapper &) = delete; + MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete; + + MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~MatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEMatmulConfig. + * + * \return NVTEMatmulConfig held by this MatmulConfigWrapper. + */ + operator NVTEMatmulConfig() const noexcept { return config_; } + + /*! \brief Set bias tensor. */ + void set_bias_tensor(NVTETensor bias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set bias gradient tensor. */ + void set_dbias_tensor(NVTETensor dbias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to compute GELU in GEMM epilogue. */ + void set_with_gelu_epilogue(bool with_gelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, + &with_gelu_epilogue, sizeof(bool)); + } + + /*! \brief Set whether to compute GELU backward in GEMM epilogue. */ + void set_with_dgelu_epilogue(bool with_dgelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, + &with_dgelu_epilogue, sizeof(bool)); + } + + /*! \brief Set auxiliary tensor for GEMM epilogue. 
*/ + void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor, + &epilogue_aux_tensor, sizeof(NVTETensor)); + } + + /*! \brief Set whether to use split accumulator for FP8 GEMM. */ + void set_use_split_accumulator(bool use_split_accumulator) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, + &use_split_accumulator, sizeof(bool)); + } + + /*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */ + void set_sm_count(int sm_count) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int)); + } + + private: + /*! \brief Wrapped NVTEMatmulConfig. */ + NVTEMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine +#endif // __cplusplus + #endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h new file mode 100644 index 0000000000..a0dd325da0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file hadamard_transform.h + * \brief Functions for Hadamard transforms. + */ + +#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Perform a randomized Hadamard transform on the input tensor. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. 
+ * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the absolute maximum reduction on the input tensor with/without + * randomized hadamard transform. The rowwise result is the absolute maximum + * of the input tensor. The columnwise result is the absolute maximum of the + * input tensor transposed and applied randomized hadamard transformation. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 2fc8c1095c..6e1e9dd7ac 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -122,6 +122,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, size_t start_offset, size_t block_len, const NVTEDType out_dtype, cudaStream_t stream); +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index dab4fcfe75..1a901ab82d 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -66,6 +66,7 @@ enum NVTETensorParam { kNVTEAmax = 3, /*!< Amax tensor */ kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTENumTensorParams }; @@ -88,10 +89,9 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, - /*! 
Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD), - and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). - */ - NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, + /*! Single scale per block of 16 elements consecutive in either + * rowwise or columnwise direction */ + NVTE_NVFP4_1D_SCALING = 4, NVTE_INVALID_SCALING = 100 }; @@ -330,6 +330,12 @@ enum NVTEQuantizationConfigAttribute { * likely be refactored away in the future. */ kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3, + /*! RNG state (NVTETensor with 2 elements - seed and offset) */ + kNVTEQuantizationConfigRNGState = 4, + /*! Whether to use 2D block scaling for NVFP4 */ + kNVTEQuantizationConfigNVFP42DQuantization = 5, + /*! Whether to enable stochastic rounding */ + kNVTEQuantizationConfigStochasticRounding = 6, kNVTEQuantizationConfigNumAttributes }; @@ -431,6 +437,15 @@ inline bool is_fp8_dtype(const DType t) { */ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } +/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16) + * + * Return true if TE datatype is high precision + * \param[in] DType TE Datatype of interest + */ +inline bool is_high_precision_dtype(const DType t) { + return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16; +} + /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. */ @@ -566,6 +581,11 @@ class TensorWrapper { return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); } + template + TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -590,6 +610,10 @@ class TensorWrapper { return get_parameter(kNVTEColumnwiseScaleInv); } + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEColumnwiseAmax); + } + /*! 
\brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -838,6 +862,24 @@ class QuantizationConfigWrapper { &format, sizeof(Float8BlockScaleTensorFormat)); } + /*! \brief Set stochastic rounding state */ + void set_rng_state(NVTETensor rng_state) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to use 2D block scaling for NVFP4 */ + void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization, + &nvfp4_2d_quantization, sizeof(bool)); + } + + /*! \brief Set whether to use stochastic rounding */ + void set_stochastic_rounding(bool stochastic_rounding) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, + &stochastic_rounding, sizeof(bool)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 398c0acbdd..5785fd2233 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -63,7 +63,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || 
is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 82e360ed64..a3b05f7a29 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -49,7 +49,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index fc8d73a136..ea0287ef15 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -4,7 +4,6 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations -import warnings import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple @@ -23,9 +22,12 @@ class _FormatHelper(NamedTuple): class Format(Enum): """ Supported FP8 formats. + Supported FP4 formats. 
Values ------ + E2M1 : + All FP4 tensors are in e2m1 format E4M3 : All FP8 tensors are in e4m3 format E5M2 : @@ -35,6 +37,7 @@ class Format(Enum): FP8 tensors in the backward pass are in e5m2 format """ + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) @@ -42,9 +45,13 @@ @dataclass(frozen=True) class MMParams: - """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) - apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, - so only turn it on for certain gemms + """Matrix multiplication options. + + Parameters + ---------- + use_split_accumulator : bool, default = `True` + Use FP8 fast accumulation on Hopper or Ada. For more details, + see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul. """ use_split_accumulator: bool = True @@ -55,10 +62,24 @@ class QParams: """Quantization parameters. power_2_scale: use power of 2 scale parameter amax_epsilon: optional minimum value of abs max + random_hadamard_transform: whether to use random hadamard transform + stochastic_rounding: whether to use stochastic rounding """ power_2_scale: bool = False amax_epsilon: float = 0.0 + random_hadamard_transform: bool = False + stochastic_rounding: bool = False + fp4_2d_quantization: bool = False + + def __repr__(self) -> str: + return ( + f"Qparams(\npower_2_scale={self.power_2_scale},\n" + f"amax_epsilon={self.amax_epsilon},\n" + f"random_hadamard_transform={self.random_hadamard_transform},\n" + f"stochastic_rounding={self.stochastic_rounding},\n" + f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + ) class Recipe: @@ -66,6 +87,10 @@ Base recipe class. 
""" + def nvfp4(self): + """Whether the given recipe is NVFP4 1D block scaling.""" + return isinstance(self, NVFP4BlockScaling) + def mxfp8(self): """Whether the given recipe is MXFP8 block scaling.""" return isinstance(self, MXFP8BlockScaling) @@ -351,3 +376,84 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class NVFP4BlockScaling(Recipe): + """ + Use the NVFP4 scaling strategy. + + This is a 2-level block scaling strategy. In level 1, each group of + 16 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E4M3 (4 bits of exponent, + 3 bits of mantissa). In level 2, a global per tensor FP32 scaling + factor is used to scale the entire tensor. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. + fp8_format : {Format.E4M3}, default = Format.E4M3 + FP8 data type. Only E4M3 is supported. + fp8_dpa: bool, default = `False` + FP8 dot product attention. Not yet supported. + fp8_mha: bool, default = `False` + FP8 multi-head attention. Not yet supported. 
+ """ + + # Configuration envvars + disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" + disable_stochastic_rounding: bool = ( + os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" + ) + disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + + fp4_format: Format = Format.E2M1 + fp8_format: Format = Format.E4M3 + + # Not applying quantization to attention for now + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" + assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + + # Quantization params + # Note: RHT is currently only applied to column-wise usage so that + # it can be used for wgrad GEMM. + self.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=False, + fp4_2d_quantization=False, + ) + self.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, + stochastic_rounding=False, + fp4_2d_quantization=not self.disable_2d_quantization, + ) + self.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=not self.disable_stochastic_rounding, + fp4_2d_quantization=False, + ) + + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"fp4_format={str(self.fp4_format).split('.')[1]}, " + f"fp8_format={str(self.fp8_format).split('.')[1]}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}, " + f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " + f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " + f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " + ) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index fd907efcba..ee2c845159 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ 
-20,6 +20,13 @@ namespace { constexpr int amax_kernel_threads = 512; +__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *amax_ptr = 0; +} + template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, const size_t N, @@ -65,7 +72,8 @@ template void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max - NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); // Return immediately if tensor is empty if (N == 0) { @@ -130,15 +138,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); auto &output = *convertNVTETensorCheck(output_); - NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + output.scaling_mode == NVTE_NVFP4_1D_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or " + "NVFP4 1D scaling, " "but got scaling_mode=", to_string(output.scaling_mode)); NVTE_CHECK(output.amax.numel() == 1, "Output tensor for amax computation has invalid amax tensor " "(expected 1 entry, got shape=", output.amax.shape, ")"); - NVTE_CHECK(output.amax.dptr != nullptr, + NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr, "Output tensor for amax computation has amax tensor without data"); NVTE_CHECK(output.amax.dtype == DType::kFloat32, "Output tensor for amax computation has invalid amax tensor " @@ -157,11 +167,12 @@ void compute_amax_impl(const NVTETensor 
input_, const NVTETensor output_, cudaSt } // Compute amax + float *amax_ptr = reinterpret_cast( + (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); - launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), - noop_ptr, stream);); // NOLINT(*) + input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel( + reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr, + stream);); // NOLINT(*) } } // anonymous namespace diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu new file mode 100644 index 0000000000..5ebc7ba4f3 --- /dev/null +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { +namespace nvfp4_recipe { + +// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; +constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); + +// Kernel to compute alpha *= amax_A * amax_B / factor +__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, + const float *amax_B, float *alpha_out) { + // factor is defined in the enclosing namespace + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +} + +} // namespace nvfp4_recipe +} // namespace transformer_engine + +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale); + using namespace transformer_engine; + + auto *tA = convertNVTETensor(inpA); + auto *tB = convertNVTETensor(inpB); + auto *tOut = convertNVTETensor(alpha_out); + + void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; + void *amax_B_ptr = use_rowwise_amax_B ? 
tB->amax.dptr : tB->columnwise_amax.dptr; + void *alpha_ptr = tOut->data.dptr; + + // check for not null pointers + NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); + NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null"); + NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null"); + + nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( + alpha_in, reinterpret_cast(amax_A_ptr), + reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 9ec86a37c6..36e06173d0 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -18,7 +18,9 @@ namespace transformer_engine { namespace { -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr int MXFP8_BLOCK_SIZE = 32; +constexpr int NVFP4_BLOCK_SIZE = 16; + constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int original_K = kernel_args.original_k_list[tensor_id]; constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); - constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; - constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; // Get block index in grid. Emulate 2D grid. 
const int num_tiles_k = K / SF_TILE_DIM_K; @@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); - } + NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || + input->scaling_mode == NVTE_BLOCK_SCALING_1D || + input->scaling_mode == NVTE_BLOCK_SCALING_2D || + input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), + "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); // Do nothing if tensor is empty if (input->data.numel() == 0) { @@ -345,123 +349,150 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s CheckInputTensor(*output, "scaling_factor_output"); auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Unsupported scaling mode for swizzling."); + + bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? 
input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); + int m, k; + if (input->has_data()) { + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else { + if (nvfp4) { + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } else { + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; } + } - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - 
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle. 
+ const bool rowwise_swizzle = input->has_data() || nvfp4; + const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4; + + dim3 block_size(TB_DIM, TB_DIM); + if (rowwise_swizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M, original_K; + void *input_scale_inv_ptr, *output_scale_inv_ptr; + + if (!nvfp4 || input->has_data()) { + int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + original_M = input->flat_first_dim(); + original_K = input->flat_last_dim() / block_scale_size; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + } else { + original_M = input->flat_last_dim(); + original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 2: - 
NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; } + } + if (columnwise_swizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K 
= input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); - // 2D block scaling - } else { - NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -551,6 +582,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } NVTE_CHECK_CUDA(cudaGetLastError()); } + +// TODO(nvfp4): Add NVFP4 support. void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); @@ -677,7 +710,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, * WIP (Phuong): * - Opt for bank conflicts * - Adding swizzle for 2d-block scaling. 
-*/ + */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 55654989a7..f49fe239aa 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" @@ -63,8 +64,8 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: - return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; + case NVTE_NVFP4_1D_SCALING: + return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -94,12 +95,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || - t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; - const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
32 : 16; + const size_t block_size_rowwise = 32; const size_t block_size_colwise = 32; if (t.has_data()) { @@ -110,6 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -122,11 +123,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } + } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { + if (t.has_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } } } } @@ -154,6 +173,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { "(expected Float32 or Byte, got ", 
                to_string(t.columnwise_scale_inv.dtype), ")");
   }
+  } else if (is_fp4_dtype(type)) {
+    // TODO(ksivaman): Fix this to check for amaxes and other details.
+    // For now only needed for swizzle.
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+                 "_scale_inverse must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
+                 "_scale_inverse has invalid dtype "
+                 "(expected DType::kFloat8E4M3, got ",
+                 to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+                 "_columnwise_scale_inverse must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ",
+                 name,
+                 "_columnwise_scale_inverse has invalid dtype "
+                 "(expected DType::kFloat8E4M3, got ",
+                 to_string(t.columnwise_scale_inv.dtype), ")");
+    }
   } else {
     NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
     NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
@@ -195,10 +234,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
                "(expected Float32 or Float8E8M0, got ",
                to_string(t.columnwise_scale_inv.dtype), ")");
   }
+  } else if (is_fp4_dtype(type)) {
+    // FP4 output needs to have the scale_inv
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+                 "_scale_inverse must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
+                 "_scale_inverse has invalid dtype "
+                 "(expected Float8E4M3, got ",
+                 to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+                 "_columnwise_scale_inverse must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling 
factor output ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - // Note: amax is supported for non-FP8 output as it can be fused into the computation - // and later used for quantization with no need to compute it separately + // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax. + // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); @@ -491,6 +549,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, case kNVTEColumnwiseScaleInv: t->columnwise_scale_inv = *param; break; + case kNVTEColumnwiseAmax: + t->columnwise_amax = *param; + break; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -514,6 +575,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p return t.scale_inv; case kNVTEColumnwiseScaleInv: return t.columnwise_scale_inv; + case kNVTEColumnwiseAmax: + return t.columnwise_amax; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -629,6 +692,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(&config_.rng_state, buf, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size); + break; + case kNVTEQuantizationConfigStochasticRounding: + std::memcpy(&config_.stochastic_rounding, buf, attr_size); + break; default: NVTE_ERROR("Unsupported 
NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index abfa226e88..89266f4bbc 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::detail { @@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor const bool pow_2_scale, const SimpleTensor &noop_tensor, cudaStream_t stream); +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, + SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor &noop_tensor, cudaStream_t stream); + } // namespace transformer_engine::detail #endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu new file mode 100644 index 0000000000..eced2c4bb6 --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -0,0 +1,842 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" + +namespace transformer_engine { + +#if CUDA_VERSION >= 12080 +namespace quantize_transpose_nvfp4 { +namespace { + +using std::int32_t; +using std::uint32_t; +using std::uint8_t; + +using transformer_engine::detail::TypeExtrema; + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | 
++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... 
| ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... 
| Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr int kThreadsPerWarp = 32; + +// for fp4, we use uint8_t to store 2 fp4 numbers +constexpr int kNFP4PerContainer = 2; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; +// constexpr int kScaleDim = 32; +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16 +constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 +// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +// for 2D block scaling, we need to reduce amax in warp +static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 
0x80808080}; + +// max for every group_size elements in warp +template +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { + for (int offset = group_size / 2; offset > 0; offset /= 2) { + val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride)); + } + return val; +} + +template +__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax, + const float global_encode_scale) { + float decode_scale = amax / TypeExtrema::max; + decode_scale = decode_scale * global_encode_scale; + decode_scale = fminf(decode_scale, TypeExtrema::max); + return static_cast(decode_scale); +} + +template +__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale, + const float global_decode_scale) { + return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), + TypeExtrema::max); +} + +template +__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) { + return static_cast(input) * encode_scale; +} + +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float fp8_max = TypeExtrema::max; + constexpr float fp4_max = TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.f || global_encode_scale == 0.f) { + return 1.f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +template +__device__ __forceinline__ size_t 
scale_factor_swizzled_offset(size_t row_idx, size_t col_idx, + uint32_t col_length) { + // This function takes in indices from the scale factor matrix and returns an offset in the + // swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled + // index). col_length is the column length of the scale factor matrix. tile_scales_inv is the + // pointer to the scale factor matrix. + + // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts + // For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4 + // columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4 + // columns. + + // NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix. + // To think in high level, the swizzled scale factor matrix could be composed as: + // unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8) + // cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64 + // rb_cnt = M // 128 # Assuming M is divisible by 128 + // tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4) + // tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + // swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4)) + + constexpr uint32_t kTotalRowsPerBaseBlock = 128; + constexpr uint32_t kRowsPerBaseBlockCol = 32; + constexpr uint32_t kColsPerBaseBlockCol = 4; + + const size_t rb = row_idx / kTotalRowsPerBaseBlock; + const size_t rem = row_idx % kTotalRowsPerBaseBlock; + const size_t d4 = rem / kRowsPerBaseBlockCol; + const size_t d3 = rem % kRowsPerBaseBlockCol; + const size_t cbg = col_idx / kColsPerBaseBlockCol; + const size_t d5 = col_idx % kColsPerBaseBlockCol; + + const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol); + // row-major offset in the logical shape + // (rb_cnt , cbg_cnt , 32 , 4 , 4) + // Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 
([0-128] / + // 32 = [0-4]) + return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +template +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const uint32_t rbits) { + if constexpr (kApplyStochasticRounding) { + return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); + } else { + return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits); + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, const float* global_amax, OType* const output_c, + OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, + const size_t row_length, const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, + const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, + const float* noop_ptr) { + constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const size_t block_idx_x = blockIdx.x; + const size_t block_idx_y = blockIdx.y; + const size_t rng_sequence = + threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. + // Instead of static_assert, return early if these invalid modes are detected. + if constexpr (kIs2DBlockScaling && kIsE8Scaling) { + return; + } + if constexpr (kIs2DBlockScaling && !kReturnIdentity) { + return; + } + // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 + // use constexpr to define the size, when not using 2D, use minimal size 1x1 + constexpr int kFP4BlockScalingSize = 16; + constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1; + constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4 + constexpr int k2DBlockAmaxReduceDim = + kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; + __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; + __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = (c_g < row_length ? 
min(static_cast(kNVecIn), row_length - c_g) + : 0); // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory + // for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; + const float global_encode_scale = + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + const float global_decode_scale = 1.0 / global_encode_scale; + + // Step 2: Cast and store to output_c + if constexpr (kReturnIdentity) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = + (c_g < row_length ? 
min(static_cast(kNVecOut / kNFP4PerContainer), + (row_length - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // doing shuffle sync for 2D block scaling (not applicable for E8 scaling) + if constexpr (kIs2DBlockScaling) { + // first amax shuffle sync in warp, then reduce in smem + // T0 T8 T16 T24 should do amax reduction together + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int 
tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + // now T0 ~ T8 in each warp has the reduced amax values + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = fmaxf(amax_2d, + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + // every thread now knows 2D amax + amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; + } + // Step 2.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + write_scale_inv &= (c_g < row_length); + } + if (write_scale_inv) { + size_t row_idx = block_idx_y * kTileDim + r_s; + size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce; + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(row_length, kScaleBlockDim)); + tile_scales_inv_c[offset] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 
f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory + // for not aligned case) + output_g += stride_g / kNFP4PerContainer; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = (c_g < num_rows ? 
min(static_cast(kNVecOut / kNFP4PerContainer), + (num_rows - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; + if constexpr (kIs2DBlockScaling) { + // TODO(zhongbo): 2D block scaling, directly read from amax_smem + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + constexpr int kNumColsPerWarp = + kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements + constexpr int kNumWarpsPerBlock = + kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block + constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; + int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; + int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; + int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; + amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize]; + } else { +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + } + // Step 3.3: Reduce amax + if 
constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // Step 3.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + write_scale_inv &= (c_g < num_rows); + } + if (write_scale_inv) { + size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce); + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim)); + tile_scales_inv_t[offset] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = + ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], + encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, + num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global + // memory for not aligned case) + output_g += stride_g / kNFP4PerContainer; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace quantize_transpose_nvfp4 +#endif // CUDA_VERSION >= 12080 + +namespace detail { + +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor& noop_tensor, cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); +#if CUDA_VERSION >= 12080 + + // pow 2 scale is for MXFP4 since it's using E8M0 scaling + // raise error if pow2_scale is true + NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now"); + + if (!return_identity && !return_transpose) { + return; + } + + if (use_2d_quantization && !return_identity) { + return; + } + + const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + + if (return_identity) { + scale_stride_x = 1; + scale_stride_y = scale_inv.shape[1]; + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + using namespace transformer_engine::quantize_transpose_nvfp4; + + const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); + + // noop tensor for cuda graph + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + + const size_t* rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY( + output.dtype, 2, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; + constexpr bool kPow2Scale = false; + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + 
TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled_scale, kSwizzledScale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, + num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, + noop_ptr);) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace detail +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 50ff82d85f..6093b54b6d 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -598,6 +598,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { const e8m0_t 
biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { @@ -828,6 +829,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; @@ -947,6 +949,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; @@ -1260,7 +1263,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_gated(gated_input, output, stream); } } - } else if (is_mxfp_scaling(output->scaling_mode)) { + } else if (is_mxfp8_scaling(output->scaling_mode)) { if (use_tma_kernels) { cast_mxfp8_gated(grad, gated_input, output, stream); } else { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 8d87351181..b0498602b5 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -23,6 +23,7 @@ #include "../util/vectorized_pointwise.h" #include "../utils.cuh" #include "math.h" +#include "nvfp4_transpose.cuh" #include "ptx.cuh" #include "transformer_engine/transformer_engine.h" @@ -108,6 +109,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + 
tid_Y_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -135,8 +138,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_sh = reinterpret_cast(dshmem); IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; @@ -284,7 +288,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float scaled_out = in * block_scale_inverse; const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } } @@ -408,10 +412,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -439,7 +445,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } } @@ -454,19 +460,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, 
reinterpret_cast(&out_rowwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); } if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } // Create a "bulk async-group" out of the previous bulk copy operation. @@ -487,18 +493,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Added extra 1-element padding per thread_X to reduce bank conflicts float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const size_t shmem_thread_offset = + const int shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; + const int shmem_elt_idx = swizzled_group_offset + e; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } @@ -506,15 +512,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int i = 0; i < THREADS_Y; ++i) { // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; + const int scaling_block = threadIdx.x / SCALE_DIM_X; thread_partial_dbias += partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } 
- const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { dbias_workspace[dbias_idx] = thread_partial_dbias; @@ -536,6 +542,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } // namespace mxfp8_kernel +namespace nvfp4_kernel { + +using namespace ptx; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap 
tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * 
CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + 
constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if 
constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float 
in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. 
Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. 
+ if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_kernel + constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_THREADS_PER_CHUNK = 128; @@ -898,7 +1426,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, } template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { const size_t N = product(input.data.shape); const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); @@ -1179,6 +1707,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ); // NOLINT(*) } +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. 
r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +template +void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { + using namespace nvfp4_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? 
ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 
32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + }); // NOLINT(*) + ); // NOLINT(*) +} + namespace detail { using Empty = transformer_engine::Empty; @@ -1386,20 +2049,33 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o auto dbias_tensor = convertNVTETensor(dbias); auto workspace_tensor = convertNVTETensor(workspace); - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } - // extract 
noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + // Dispatch to quantization kernel depending on data format switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (output_tensor->has_columnwise_data()) { NVTE_CHECK(output_tensor->has_data(), "Quantizing in only the columnwise direction not supported yet!"); if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); } else { cast_transpose_fused( *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, @@ -1407,51 +2083,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o } } else if (output_tensor->has_data()) { fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } break; } case NVTE_MXFP8_1D_SCALING: { mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; } + case NVTE_NVFP4_1D_SCALING: { + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, 
"input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && + output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " + "2D quantization"); + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? 
quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); + /*noop_tensor=*/noop_tensor->data, stream); break; } case NVTE_BLOCK_SCALING_1D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; } if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? 
quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); columnwise_option = columnwise_compact ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -1459,7 +2174,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } default: diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e2d8d34f3d..9f70ce4cd4 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include "../common.h" @@ -26,6 +28,7 @@ #include "math.h" #include "ptx.cuh" #include "transformer_engine/activation.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -226,7 +229,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); @@ -247,7 +250,7 
@@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ); // NOLINT(*) } -static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ -331,6 +334,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } + +#if CUDA_VERSION >= 12080 +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // CUDA_VERSION + +void 
fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // CUDA_VERSION >= 12080 +} + } // namespace dequantization namespace detail { @@ -339,17 +417,25 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + 
dequantization::fp8_dequantize(input, output, stream); + break; } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + dequantization::fp4_dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh new file mode 100644 index 0000000000..fe9736298d --- /dev/null +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -0,0 +1,1515 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_transpose.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. 
+ */ + +#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ +#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ + +#include +#include +#include + +#if CUDA_VERSION > 12080 +#include +#endif // CUDA_VERSION > 12080 + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "curanddx.hpp" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +#if CUDA_VERSION > 12080 +namespace nvfp4_transpose { + +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +using namespace ptx; +using nvfp4_scale_t = fp8e4m3; + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; +constexpr size_t RNG_GENS_PER_THREAD = + SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? 
Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); 
+static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits 
dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. 
+#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. 
+#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} + +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, 
const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + 
thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t 
shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const 
size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = 
(block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const 
bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? 
SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool 
colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + 
elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. 
Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = 
*reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? 
SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_transpose +#endif // CUDA_VERSION > 12080 + +// Compile-time flag to choose kernel variant +#ifndef USE_2D_NVFP4_KERNEL +#define USE_2D_NVFP4_KERNEL 0 +#endif + +template +void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if CUDA_VERSION > 12080 + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this? 
+ bool return_transpose = output->has_columnwise_data(); + + using namespace nvfp4_transpose; + using namespace ptx; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = output->columnwise_scale_inv.shape[1]; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor 
rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + 
TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = nvfp4_transpose_kernel; + + if constexpr (use_2d_quantization) { + kernel = nvfp4_transpose_kernel_2D; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION > 12080 +} +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 581de9f9fd..85717afdf2 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -14,6 +14,10 @@ #include #include +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + namespace transformer_engine { namespace ptx { @@ -117,9 +121,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { return __int_as_float(biased_exp << FP32_MANTISSA_BITS); } +#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \ + ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL))) + __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out; asm volatile( "{\n" @@ -222,18 +230,86 @@ struct alignas(2 * sizeof(T)) FPx2 { T y; }; +template +struct FPx4 { + T x1; + T x2; + T x3; + T x4; +}; + +template +struct Type2x {}; + +template <> +struct Type2x { + 
using type = float2; +}; + +template <> +struct Type2x { + using type = __nv_bfloat162; +}; + +template <> +struct Type2x { + using type = __half2; +}; + using floatx2 = FPx2; using bf16x2 = FPx2; using fp16x2 = FPx2; using fp8e4m3x2 = FPx2; using fp8e5m2x2 = FPx2; +using floatx4 = FPx4; +using bf16x4 = FPx4; +using fp16x4 = FPx4; +using fp8e4m3x4 = FPx4; +using fp8e5m2x4 = FPx4; + static_assert(sizeof(floatx2) == 8); static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); +#if CUDA_VERSION >= 12080 +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; +static_assert(sizeof(fp4e2m1x2) == 1); +static_assert(sizeof(fp4e2m1x4) == 2); +#endif // CUDA_VERSION >= 12080 + +// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1 + +// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6. + +// vt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures: +// sm_100a +// sm_101a +// sm_120a + +// When converting to .e2m1x2 data formats, the destination operand d has .b8 type. +// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, +// and the converted values are packed in the destination operand d such that the value +// converted from input a is stored in the upper 4 bits of d and the value converted +// from input b is stored in the lower 4 bits of d. 
+ +// SIMD like "Fused" cast + multiplication (x4) +#if CUDA_VERSION >= 12080 +template +__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, + const float scale) { + const float x0 = static_cast(in01.x) * scale; + const float x1 = static_cast(in01.y) * scale; + const float x2 = static_cast(in23.x) * scale; + const float x3 = static_cast(in23.y) * scale; + out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); +} +#endif // CUDA_VERSION >= 12080 + // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -369,7 +445,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const "r"(reinterpret_cast(p2))); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 68b7aa8bbe..bce124e705 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -22,7 +22,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 3f5bcc975d..bc764ac746 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -35,6 +35,26 @@ constexpr uint32_t 
THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// +// Device-side error +#define NVTE_DEVICE_ERROR(message) \ + do { \ + printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \ + __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \ + (message)); \ + assert(0); \ + } while (false) + +// Device-side error on thread 0 +#define NVTE_DEVICE_THREAD0_ERROR(message) \ + do { \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \ + threadIdx.y == 0 && threadIdx.z == 0) { \ + NVTE_DEVICE_ERROR(message); \ + } \ + } while (false) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) return {a.x + b.x, a.y + b.y}; } diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index d1470e22e3..a1fae730c5 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -89,3 +89,5 @@ dist_group_type = torch.distributed.ProcessGroup MXFP8_BLOCK_SCALING_SIZE = 32 + +NVFP4_BLOCK_SCALING_SIZE = 16 diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e4f4e619fe..d330e023ea 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,6 +13,8 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.utils import is_experimental +from ..experimental.gemm import experimental_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ @@ -77,6 +79,24 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + # If A or B 
are experimental tensors -> dispatch to quantizers's qgemm implementation + if is_experimental(A) or is_experimental(B): + return experimental_gemm( + A, + B, + workspace, + out_dtype, + quantization_params, + gelu, + gelu_in, + accumulate, + layout, + out, + bias, + use_split_accumulator, + grad, + ) + debug_quantizer = None if isinstance(quantization_params, DebugQuantizer): debug_quantizer = quantization_params diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index dffb899f7e..49ae963d74 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,6 +12,20 @@ namespace transformer_engine::pytorch { +/*! convert fp4 data shape back to original shape */ +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { + std::vector ret; + size_t start_idx = (transpose) ? 1 : 0; + for (size_t i = start_idx; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() * 2); + if (transpose) { + ret.push_back(shape.front()); + } + return ret; +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { @@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, + arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, + at::cuda::getCurrentCUDAStream()); + }); +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) { + at::PhiloxCudaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + } // namespace transformer_engine::pytorch 
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2d35de8522..c94bd0d2a5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -194,20 +195,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - /*! @brief Construct a high precision tensor giving it this quantizer's amax - - Note: this member function also zeros out the amax, as it is meant to be used in conjunction with - a kernel computing the amax, which might expect the amax to be initialized to zero + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. */ - std::pair create_hp_tensor_with_amax(const std::vector& shape, - DType dtype); + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype); std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - /*! @brief Convert to a quantized data format avoiding amax computation */ + /*! @brief Quantize to FP8, skipping local amax computation + * + * The quantizer's amax pointer is assumed to already hold the local + * amax. The amax may still be reduced across the amax reduction + * group. 
+ */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); @@ -277,6 +283,60 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +class NVFP4Quantizer : public Quantizer { + public: + // fp4 dtype + DType dtype; + // amax reduction for low precision FP4 AG + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + // random hadamard transform + bool with_rht; + bool with_post_rht_amax; + // 2D block scaling + bool with_2d_quantization; + bool stochastic_rounding; + + int rht_matrix_random_sign_mask_t; + at::Tensor rht_matrix; + + explicit NVFP4Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype); + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Quantize to NVFP4, skipping local amax computation + * + * The input tensor's amax pointer is assumed to already hold the + * local amax. The amax may still be reduced across the amax + * reduction group. 
+ */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); +}; + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(const at::Tensor& t); @@ -420,6 +480,15 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); + +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); + +// unpack the PhiloxCudaState into CUDA tensor +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); + } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7851cc5ffc..cdfb4be408 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -8,179 +8,269 @@ #include "common.h" #include "pybind.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { +namespace { + +py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t), + const at::Tensor& input, py::handle quantizer, + int shape_divisor = 1) { init_extension(); // Input tensor auto input_tensor = input.contiguous(); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); 
// Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_cpp.shape(); + const auto input_shape = input_nvte.shape(); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); output_shape.back() /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - // Compute activation + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation in high-precision fused together with amax, then quantize. 
- - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); - } else { - // Compute activation in high-precision, then quantize - - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp->quantize(temp_cpp, out_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + quantizer_cpp->quantize(temp_nvte, out_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation directly + { + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + 
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { +py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, + cudaStream_t), + const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer) { init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_cpp.shape(); + const auto input_shape_te = input_nvte.shape(); const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto 
[grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - // Compute activation backward + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation backward directly - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation backward in high-precision fused together with amax, then quantize. - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); - } else { - // Compute activation backward in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = 
Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation backward in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation backward directly + { + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return grad_input_py; } 
-/* GELU and variants*/ +} // namespace + +/* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_gelu, input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgelu, grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_geglu, input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgeglu, grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_qgelu, input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgelu, grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_qgeglu, input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgeglu, grad, input, quantizer); } -/* ReLU and variants*/ +/* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_relu, input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return 
activation_backward(nvte_drelu, grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_reglu, input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dreglu, grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_srelu, input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsrelu, grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_sreglu, input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsreglu, grad, input, quantizer); } -/* Silu and variants*/ +/* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_silu, input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsilu, grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_swiglu, input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return 
activation_backward(nvte_dswiglu, grad, input, quantizer); } -} // namespace transformer_engine::pytorch + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 8179727e58..5db9dd73da 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); } -void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, - arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, - at::cuda::getCurrentCUDAStream()); - }); -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - } // namespace namespace transformer_engine::pytorch { @@ -198,7 +182,7 @@ std::vector fused_attn_fwd( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack(philox_args, static_cast(rng_state.data_ptr())); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index a80cb35f25..0531596dd3 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -122,13 +122,27 @@ std::vector dact_dbias( } // Choose implementation - enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + enum class Impl { + UNFUSED, + FUSED_DACT_DBIAS_QUANTIZE, + FUSED_DACT_AMAX_FP8, + FUSED_DACT_AMAX_NVFP4 + }; Impl impl = Impl::UNFUSED; if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || detail::IsMXFP8Quantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_AMAX; + impl = Impl::FUSED_DACT_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_DACT_AMAX_NVFP4; + } } // Perform compute @@ -172,20 +186,38 @@ std::vector dact_dbias( }); break; } - case Impl::FUSED_DACT_AMAX: - // Fused dact-amax kernel, unfused dbias and quantize + case Impl::FUSED_DACT_AMAX_FP8: + // Fused dact-amax kernel, unfused dbias and FP8 quantize { - auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(quantizer_cpp_cs != nullptr, + auto *fp8_quantizer_cpp = + dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); auto [temp_nvte, temp_py] = - quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, 
bias_size}), {0}); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } + case Impl::FUSED_DACT_AMAX_NVFP4: + // Fused dact-amax kernel, unfused dbias and NVFP4 quantize + { + auto *nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax( + grad_input_nvte, grad_output_dtype); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); const auto temp_torch = temp_py.cast(); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } default: diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 0d18a5ec5b..1364597519 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -213,6 +213,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + if (grad) { + config.set_dbias_tensor(bias_tensor.data()); + config.set_with_dgelu_epilogue(gelu); + } else { + config.set_bias_tensor(bias_tensor.data()); + config.set_with_gelu_epilogue(gelu); + } + config.set_epilogue_aux_tensor(te_pre_gelu_out.data()); + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sms); + // Keep the swizzled scaling factor tensors alive during the GEMM. 
std::vector> swizzled_scale_inverses_list; auto main_stream = at::cuda::getCurrentCUDAStream(); @@ -276,10 +289,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), - bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, - te_workspace.data(), alpha, *beta, use_split_accumulator, - num_math_sms, main_stream); + nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(), + out_tensor.data(), out_tensor.data(), te_workspace.data(), config, + main_stream); }); } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index c63f892cea..3fa0fb0aa3 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -66,67 +66,102 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); - TensorWrapper bias_cu; + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_nvte; if (bias.has_value()) { - bias_cu = makeTransformerEngineTensor(*bias); + bias_nvte = makeTransformerEngineTensor(*bias); } // Tensor dimensions - const size_t N = static_cast(input_cu.size(0)); - const size_t H = static_cast(input_cu.size(1)); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - at::Tensor mu = 
at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - TensorWrapper mu_cu = makeTransformerEngineTensor(mu); - TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - 
force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; + } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; } } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = 
static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -138,24 +173,31 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - 
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + default: { } } - return {out, py::cast(mu), py::cast(rsigma)}; + return {out, py::cast(mu_py), py::cast(rsigma_py)}; } std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -254,61 +296,95 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const size_t N = static_cast(input_cu.shape().data[0]); - const size_t H = static_cast(input_cu.shape().data[1]); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = 
convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; + } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; + } else if 
(detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; } } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -320,24 +396,30 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + default: { } } - return {out, py::none(), py::cast(rsigma)}; + return {out, py::none(), 
py::cast(rsigma_py)}; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7649ccb6d6..98f71f9a7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -32,6 +32,9 @@ PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +PyTypeObject *NVFP4TensorPythonClass = nullptr; +PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -86,10 +89,26 @@ void init_float8blockwise_extension() { "Internal error: could not initialize pyTorch float8blockwise extension."); } +void init_nvfp4_extensions() { + if (NVFP4TensorPythonClass) return; + auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); + NVFP4QuantizerClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); + NVFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); + auto nvfp4_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); + NVFP4TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase")); + NVTE_CHECK(NVFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch NVFP4 extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); + init_nvfp4_extensions(); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de9..f46edaa70e 100644 --- 
a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -40,13 +40,12 @@ extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; +extern PyTypeObject *NVFP4TensorPythonClass; +extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); -void init_float8_extension(); - -void init_mxfp8_extension(); - namespace detail { inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } @@ -69,11 +68,17 @@ inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } +inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } + inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; } +inline bool IsNVFP4Tensor(PyObject *obj) { + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -88,6 +93,8 @@ std::unique_ptr CreateMXFP8Params(const py::handle params); TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantization_params); +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, - NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; - + 
NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), + std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, + CreateQuantizer)}; } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cd7e70fecb..2abe9614e1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -31,8 +31,20 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */ +template +std::vector convert_shape_for_fp4(const std::vector& shape) { + std::vector ret; + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() / 2); + return ret; +} + } // namespace +constexpr size_t NVFP4_BLOCK_SIZE = 16; constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -376,8 +388,9 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( - const std::vector& shape, DType dtype) { +std::pair +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype) { amax.zero_(); auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), @@ -899,7 +912,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1095,7 +1108,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); std::vector scale_shape; @@ -1116,4 +1129,573 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->with_rht = quantizer.attr("with_rht").cast(); + this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); + this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + + // Get amax reduction group if needed for NVFP4 AG + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + c10::intrusive_ptr amax_reduction_group; + if (with_amax_reduction) { + auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); + NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); + amax_reduction_group = group.cast>(); + } + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + + this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); + this->rht_matrix = 
quantizer.attr("rht_matrix").cast(); +} + +void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { + // set dtype for rowwise and columnwise data in tensor wrapper + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(this->dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(this->dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { + using namespace pybind11::literals; + + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; + const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } + if (columnwise_usage) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data_tensor = + at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } + + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + 
return need_cast ? py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage); + auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage); + + // Construct Python NVFP4 tensor + py::object out_py; + if (internal) { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + out_py = NVFP4TensorClass( + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } else { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); + out_py = NVFP4TensorClass( + "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // 
and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + columnwise_scale_inv_shape); + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype) { + // Construct tensor + auto shape = convertShape(quantized_tensor.shape()); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + + // Register amax pointer from quantized tensor + void* amax_ptr = quantized_tensor.amax(); + if (amax_ptr == nullptr) { + amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + } + NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); + out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + + // Zero out amax + NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + 
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + auto amax_rowwise = get_tensor("_amax_rowwise"); + auto amax_columnwise = get_tensor("_amax_columnwise"); + NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); + + // Tensor dimensions, shape means original shape + std::vector shape; + if (columnwise_data) { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + if (rowwise_data) { + auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked columnwise_data_tensor == true + shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + } + + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + if (!amax_rowwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_rowwise = at::empty({1}, opts); + tensor.attr("_amax_rowwise") = *amax_rowwise; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + if (amax_rowwise) { + amax_rowwise.reset(); + tensor.attr("_amax_rowwise") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + 
const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + if (!amax_columnwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_columnwise = at::zeros({1}, opts); + tensor.attr("_amax_columnwise") = *amax_columnwise; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + if (amax_columnwise) { + amax_columnwise.reset(); + tensor.attr("_amax_columnwise") = py::none(); + } + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void 
NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + quant_config.set_stochastic_rounding(this->stochastic_rounding); + + // We only need RHT for columnwise usage. + // flat first dim and last dim for multi dimensional input + size_t rows = 1; + for (size_t i = 0; i < input.ndim() - 1; ++i) { + rows *= input.size(i); + } + size_t cols = input.size(input.ndim() - 1); + + TensorWrapper te_rng_state; + if (this->stochastic_rounding) { + const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, opts); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); + te_rng_state = makeTransformerEngineTensor(rng_state); + quant_config.set_rng_state(te_rng_state.data()); + } + + // Restriction for the RHT cast fusion kernel. + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + + // Compute amax. + if (this->with_rht) { + if (input.dtype() != DType::kBFloat16) { + NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + } + if (this->with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. 
Columnwise amax = amax for RHT(input.t) + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), out.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } else { + // raise error since it's not supported yet + NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + } + } else { // Without RHT + if (compute_amax) { + // Amax pointers + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + // Compute amax of input tensor + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Make sure row-wise and column-wise amaxes match + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } + + // amax reduction + if (this->with_amax_reduction) { + std::vector amax_tensors; + // push amax tensors inside if they need to be reduced + auto make_amax_tensor = [](void* data_ptr) { + return at::from_blob( + data_ptr, std::vector{1}, + [](void*) {}, // deleter doing nothing since it doesn't own the data + at::device(at::kCUDA).dtype(torch::kFloat32)); + }; + if (rowwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr)); + } + if (columnwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr)); + } + c10d::AllreduceCoalescedOptions opts; + opts.reduceOp = 
c10d::ReduceOp::MAX; + NVTE_SCOPED_GIL_RELEASE( + { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); + } + + if (this->with_rht) { + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. 
+ TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + if (!eligible_for_rht_cast_fusion) { + // Invoking fallback RHT kernel. + + // If using RHT, then amax will be computed in the RHT step + // If not using RHT, then amax will be computed based on input x + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. 
+ rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); + }); + } else { + // RHT cast fusion kernel. + NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, + "RHT matrix is not set"); + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion_columnwise( + input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); + }); + } + } + } else { + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); + } +} + +void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + // Update output tensor amaxes with input tensor amax + auto input_amax_ptr = input.amax(); + auto output_rowwise_amax_ptr = out.get_amax().data_ptr; + auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + NVTE_CHECK(input_amax_ptr != nullptr || + (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr), + "Input tensor does not have pre-computed amax"); + if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr && + output_rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + if 
(input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr && + output_columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + + // Perform quantization + this->quantize_impl(input, out, std::nullopt, false); +} + +std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + auto flat_first_dim = numel / last_dim; + + NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = roundup(flat_first_dim, 128); + size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = roundup(last_dim, 128); + size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index cb2121a457..368e9dcdfa 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer return ret; } +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = 
tensor.attr("_fp4_dtype").cast(); + + auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); + + // Row-scaled data + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); + ret.set_rowwise_data(data.data_ptr(), dtype, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); + ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); + ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, + getTensorShape(scale_inv)); + ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + getTensorShape(amax_columnwise)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 92f2d3a500..3bb6be715d 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -14,22 +14,31 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode 
for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } - NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8, + "4-bit or 8-bit input required for swizzling scaling factors."); + + const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING; NVTEBasicTensor scale_inv; + NVTEShape nvte_input_shape; if (rowwise) { + nvte_input_shape = input.shape(); scale_inv = input.get_rowwise_scale_inv(); } else { + nvte_input_shape = input.get_columnwise_data().shape; scale_inv = input.get_columnwise_scale_inv(); } - auto input_shape = nvte_shape_to_vector(input.shape()); + auto input_shape = nvte_shape_to_vector(nvte_input_shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape."); + // Allocate memory for swizzled output. auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); std::vector scale_inv_shape_int; @@ -41,36 +50,34 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); // Reconstruct input only to avoid swizzling both directions if not needed. - // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + // The specific dtype used is irrelevant, just needs to be correct bits. + transformer_engine::TensorWrapper input_cu(input.scaling_mode()); + transformer_engine::TensorWrapper output_cu(input.scaling_mode()); + + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? 
transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else 
{ - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } return swizzled_scale_inv; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 217cb98c74..3ab0717d0d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -39,11 +39,14 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise( return out, handle +def _swap_first_dims(tensor: torch.Tensor, world_size: int): + """ + Swap first 2 dimensions of a tensor to fix interleaved + data format after gathering transposed data. + + For more than 2 dimensions, we squash the trailing dimensions, + instead of the first few dimensions, that's because the shape + passed in this function is already transposed. + """ + + shape = tensor.shape + assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." 
+ first_dim = shape[0] + flattened_trailing = math.prod(shape[1:]) + assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) + tensor = tex.swap_first_dims(tensor, out=None) + return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) + + +def _post_process_nvfp4_gather( + out: NVFP4TensorBase, + columnwise_data_interleaved: torch.Tensor, + columnwise_scale_inv_interleaved: torch.Tensor, + world_size: int, + handle: Optional[torch.distributed.Work] = None, +) -> NVFP4TensorBase: + """Post-process FP8 blockwise gather.""" + if handle is not None: + handle.wait() + handle = None + + # Fix the interleaved transposed data from gathering along first dim. + out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) + out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + + # Optionally pad the scaling inverse if needed. 
+ out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + + +@dataclass +class _NVFP4AllGatherAsyncHandle: + """Handle for asynchronous NVFP4 all-gather.""" + + output: NVFP4TensorBase + columnwise_data_interleaved: torch.Tensor + columnwise_scale_inv_interleaved: torch.Tensor + world_size: int + async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + _post_process_nvfp4_gather( + self.output, + self.columnwise_data_interleaved, + self.columnwise_scale_inv_interleaved, + self.world_size, + ) + self._synchronized = True + + +def _all_gather_nvfp4( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: NVFP4Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]: + """All-gather NVFP4 tensor along first dimension.""" + + # Input tensor attributes + in_shape: Iterable[int] = None + in_shape_t: Iterable[int] = None + device: torch.device + dtype: torch.dtype + + # Construct packed shapes for input and input_t. + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase): + # High-precision tensor. 
+ in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) + in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( + NVFP4Quantizer.get_columnwise_shape(inp.size()) + ) + device = inp.device + dtype = inp.dtype + elif isinstance(inp, NVFP4TensorBase): + if inp._rowwise_data is not None: + in_shape = inp._rowwise_data.size() + device = inp._rowwise_data.device + if inp._columnwise_data is not None: + in_shape_t = inp._columnwise_data.size() + device = inp._columnwise_data.device + dtype = torch.bfloat16 + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, " + f"found {inp.__class__.__name__})" + ) + + assert in_shape is not None or in_shape_t is not None, "No data found." + + world_size = get_distributed_world_size(process_group) + + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to NVFP4. + if ( + not isinstance(inp, NVFP4TensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + # Cast input tensor to NVFP4 with required data + if not isinstance(inp, NVFP4TensorBase): + inp = quantizer(inp) + elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + quantizer.columnwise_usage and inp._columnwise_data is None + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to NVFP4." + ) + inp = quantizer(inp.dequantize()) + + # Construct NVFP4 output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Coalesce NCCL collectives for gathering data and scale inverses. 
+ with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as gather_coalescing_manager: + + # Gather NVFP4 data for row-wise usage + if quantizer.rowwise_usage: + + # Remove padding from NVFP4 scale-inverses + assert in_shape is not None, "Shape not found." + in_scale_inv = inp._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_rowwise = inp._amax_rowwise + + # Gather the transposed NVFP4 data along first dimension. Fix format later. + if quantizer.columnwise_usage: + + # Remove padding from NVFP4 scale-inverses + # For doing an all-gather on transposed scale inverses, + # we need to remove padding from both dimension. + in_scale_inv = inp._columnwise_scale_inv + # take caution that for in_shape_t, flatten in the trailing dimensions! + flattened_in_shape0 = in_shape_t[0] + flattened_in_shape1 = math.prod(in_shape_t[1:]) + + # Remove dim0 padding + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + + # Remove dim1 padding (pack first). + unpadded_dim1 = flattened_in_shape1 * 2 // 16 + if in_scale_inv.size(1) != unpadded_dim1: + in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous() + + # Construct tensor to gather transposed scale_inv (interleaved) and launch AG. 
+ out_scale_inv = torch.empty( + [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]], + dtype=in_scale_inv.dtype, + layout=in_scale_inv.layout, + device=in_scale_inv.device, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + + # Construct tensor to gather transposed data (interleaved) and launch AG. + out_columnwise_data = torch.empty( + [inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]), + dtype=inp._columnwise_data.dtype, + layout=inp._columnwise_data.layout, + device=inp._columnwise_data.device, + ) + torch.distributed.all_gather_into_tensor( + out_columnwise_data, + inp._columnwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_columnwise = inp._amax_columnwise + + handle = gather_coalescing_manager if async_op else None + + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. + if async_op and quantizer.columnwise_usage: + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, handle + ) + elif quantizer.columnwise_usage: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + + return out, handle + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1291,7 +1533,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1315,7 +1556,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # 
Launch all-gathers @@ -1347,7 +1587,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensor): + if quantizer is not None and not isinstance(inp, QuantizedTensorBase): inp = quantizer(inp) return inp, None @@ -1426,13 +1666,24 @@ def gather_along_first_dim( out_shape=out_shape, ) + # NVFP4 case + if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer): + assert isinstance(quantizer, NVFP4Quantizer) + return _all_gather_nvfp4( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1450,7 +1701,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." diff --git a/transformer_engine/pytorch/experimental/__init__.py b/transformer_engine/pytorch/experimental/__init__.py new file mode 100644 index 0000000000..11658f636b --- /dev/null +++ b/transformer_engine/pytorch/experimental/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Experimental features and APIs.""" + +from .config import set_qlinear_params, get_experimental_quantizers + + +__all__ = ["set_qlinear_params", "get_experimental_quantizers"] diff --git a/transformer_engine/pytorch/experimental/config.py b/transformer_engine/pytorch/experimental/config.py new file mode 100644 index 0000000000..fec6bc9383 --- /dev/null +++ b/transformer_engine/pytorch/experimental/config.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Config API for experimental middleware between Transformer Engine and Kitchen.""" + +import dataclasses +import enum +import os +from typing import Optional + +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import quantization_microblock_ref +from transformer_engine.pytorch.experimental.quantization import MMParams + + +@dataclasses.dataclass() +class QLinearParams: + """Quantization parameters of linear layer. + + Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors. 
+ """ + + x_quantizer: Optional[quantization.ExperimentalQuantizer] = None + w_quantizer: Optional[quantization.ExperimentalQuantizer] = None + g_quantizer: Optional[quantization.ExperimentalQuantizer] = None + + mm_fprop: Optional[MMParams] = None + mm_dgrad: Optional[MMParams] = None + mm_wgrad: Optional[MMParams] = None + + +@enum.unique +class QuantizeRecipe(enum.Enum): + """Pre-defined quantization recipes for linear layers.""" + + NON_QUANTIZE = "non_quantize" + NVFP4_REF = "nvfp4_ref" + NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only" + NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only" + NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization" + + +def get_qlinear_params_from_predefined( + recipe: QuantizeRecipe, +) -> Optional[QLinearParams]: + """Get quantization parameters for linear layer based on recipe.""" + if recipe == QuantizeRecipe.NON_QUANTIZE: + return None + if recipe == QuantizeRecipe.NVFP4_REF: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + if 
recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + raise ValueError(f"Unsupported quantize recipe: {recipe}") + + +def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]: + """Load quantization options from Kitchen to Transformer Engine. + + TODO(etsykunov): Confirm docstring is correct. + """ + assert qat_params_idx > 0, "QAT_PARAMS is not set." 
+ + if qat_params_idx == 6010: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF) + if qat_params_idx == 960109: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY) + if qat_params_idx == 9002: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY) + if qat_params_idx == 9003: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION) + raise ValueError(f"Unsupported QAT params index: {qat_params_idx}") + + +def set_qlinear_params( + qlinear_params: Optional[QLinearParams] = None, + layer_number: Optional[int] = None, + layer_name: Optional[str] = None, +) -> Optional[QLinearParams]: + """Set quantization parameters based on configuration. + + Args: + qlinear_params: Quantization parameters. If None, loaded from environment. + layer_number: The numerical index of this layer in the model structure. + layer_name: The name for this layer. + + Returns: + QLinearParams: The finalized quantization parameters for this layer. 
+ """ + if qlinear_params is None: + qat_params_idx = int(os.getenv("QAT_PARAMS", "0")) + if qat_params_idx == 0: + return None + return get_qlinear_params_from_qat_params(qat_params_idx) + + # Apply layer-specific overrides + if layer_number is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + if layer_name is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + + return qlinear_params + + +def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams): + """Replacement of _get_quantizers() in TE modules.""" + if not fp8: + raise ValueError("FP8 is required to be enabled for experimental quantization.") + input_quantizer = qlinear_params.x_quantizer + weight_quantizer = qlinear_params.w_quantizer + output_quantizer = None + grad_input_quantizer = None + grad_weight_quantizer = None + grad_output_quantizer = qlinear_params.g_quantizer + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) diff --git a/transformer_engine/pytorch/experimental/gemm.py b/transformer_engine/pytorch/experimental/gemm.py new file mode 100644 index 0000000000..d743b577b3 --- /dev/null +++ b/transformer_engine/pytorch/experimental/gemm.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""GEMM API for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Iterable, Optional + +import torch + +from transformer_engine.pytorch.experimental.quantization import ( + MMParams, + GEMMType, + ExperimentalQuantizedTensor, +) +from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer + + +def experimental_gemm( + A: ExperimentalQuantizedTensor, + B: ExperimentalQuantizedTensor, + workspace: torch.Tensor, # pylint: disable=unused-argument + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument + gelu: bool = False, # pylint: disable=unused-argument + gelu_in: torch.Tensor = None, # pylint: disable=unused-argument + accumulate: bool = False, # pylint: disable=unused-argument + layout: str = "TN", + out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, + use_split_accumulator: bool = False, + grad: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """Dispatch GEMM to quantizer's qgemm method.""" + assert isinstance(A, ExperimentalQuantizedTensor) and isinstance( + B, ExperimentalQuantizedTensor + ), "A and B must be ExperimentalQuantizedTensor instances" + + A, B = B, A + + # Determine GEMM type based on grad flag and layout + if not grad: + gemm_type = GEMMType.FPROP + else: + if layout == "NN": + gemm_type = GEMMType.DGRAD + elif layout == "NT": + gemm_type = GEMMType.WGRAD + else: + # Default to FPROP for other layouts + gemm_type = GEMMType.FPROP + + # Extract quantizer from QuantizedTensor to get qgemm logic + # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer? 
+ quantizer = None + if hasattr(A, "quantizer") and A.quantizer is not None: + quantizer = A.quantizer + elif hasattr(B, "quantizer") and B.quantizer is not None: + quantizer = B.quantizer + else: + raise ValueError("No quantizer found in QuantizedETensor objects") + + # Create MMParams + m_params = MMParams( + out_dtype=out_dtype, + use_split_accumulator=use_split_accumulator, + ) + out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype + + if gemm_type == GEMMType.FPROP: + qx, sx = A.data, A.scale + qw, sw = B.data, B.scale + assert qx is not None + assert sx is not None + assert qw is not None + assert sw is not None + assert A.original_shape is not None + + # Call quantizer's qgemm method + result = quantizer.qgemm( + qx, + qw, + m_params, + out_dtype, + sx, + sw, + bias, + gemm_type=GEMMType.FPROP, + qresult_x=A, + qresult_w=B, + ) + if len(A.original_shape) > 2: + # Original input was 3D, so we need to reshape result back to 3D + batch_size = A.original_shape[0] + seq_len = A.original_shape[1] + result = result.view(batch_size, seq_len, result.shape[-1]) + elif gemm_type == GEMMType.DGRAD: + qdy, sdy = A.data, A.scale + qw_t, sw_t = B.data_t, B.scale_t + assert qdy is not None + assert sdy is not None + assert qw_t is not None + assert sw_t is not None + + result = quantizer.qgemm( + qdy, + qw_t, + m_params, + out_dtype, + sdy, + sw_t, + None, + gemm_type=GEMMType.DGRAD, + qresult_x=A, + qresult_w=B, + ) + elif gemm_type == GEMMType.WGRAD: + qdy_t, sdy_t = A.data_t, A.scale_t + qx_t, sx_t = B.data_t, B.scale_t + assert qdy_t is not None + assert sdy_t is not None + assert qx_t is not None + assert sx_t is not None + + result = quantizer.qgemm( + qdy_t, + qx_t, + m_params, + out_dtype, + sdy_t, + sx_t, + None, + gemm_type=GEMMType.WGRAD, + qresult_x=A, + qresult_w=B, + ) + + # Return in the same format as general_gemm + return result, None, None, None diff --git a/transformer_engine/pytorch/experimental/quantization.py 
b/transformer_engine/pytorch/experimental/quantization.py new file mode 100644 index 0000000000..9adf4dabf8 --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization API for experimental middleware between Transformer Engine and Kitchen.""" + +from __future__ import annotations +import abc +import dataclasses +import enum +from typing import Iterable, Optional, Tuple, Union + +import torch + +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from transformer_engine.pytorch.experimental import utils + + +@enum.unique +class GEMMType(enum.Enum): + """Type of GEMM operation being performed.""" + + FPROP = "fprop" + DGRAD = "dgrad" + WGRAD = "wgrad" + + +@dataclasses.dataclass(frozen=True) +class MMParams: + """Matrix multiplication parameters.""" + + out_dtype: torch.dtype | None = None + # Use split accumulator for more accurate FP8 GEMM + use_split_accumulator: bool = True + + +@dataclasses.dataclass +class ExperimentalQuantizedTensor(QuantizedTensorBase): + """Base class for experimental quantized tensor containers. + + An experimental container to hold quantization result, including quantized tensor, optional + transposed quantized tensor, and corresponding decoding scales. + + data: torch.Tensor + the quantized tensor. + scale: torch.Tensor + the decoding scale for the quantized tensor. Shape depends on the scaling granularity. + - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. + data_t: torch.Tensor + the transposed quantized tensor (computed lazily if needed). + scale_t: torch.Tensor + the decoding scale for the transposed quantized tensor. + dtype: torch.dtype + nominal tensor datatype. + device: torch.device + device of the tensor. 
+ quant_dtype: Union[utils.Fp4Formats, torch.dtype] + low precision tensor datatype. + original_shape: Tuple[int, ...] + original shape of the tensor. + quantizer: ExperimentalQuantizer + Builder class for quantized tensor. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + data_t: Optional[torch.Tensor] = None + scale_t: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + global_amax_col: Optional[torch.Tensor] = None + + dtype: Optional[torch.dtype] = None + device: Optional[torch.device] = None + quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None + original_shape: Optional[Tuple[int, ...]] = None + quantizer: Optional[ExperimentalQuantizer] = None + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware.""" + return True + + def get_quantizer(self) -> ExperimentalQuantizer: + """Get builder for QuantizedExperimentalTensor + + Quantizer can be used for in-place operations. 
+ + """ + if self.quantizer is not None: + return self.quantizer + raise ValueError("Quantizer is not set") + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]: + """Prepare the quantization result for saving for backward""" + tensors = [self.data, self.data_t, self.scale, self.scale_t] + self.data = None + self.data_t = None + self.scale = None + self.scale_t = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the quantization result from the saved tensors""" + self.data = tensors[0] + self.data_t = tensors[1] + self.scale = tensors[2] + self.scale_t = tensors[3] + return tensors[4:] + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + # Compatibility + @property + def _data(self): + return self.data + + @_data.setter + def _data(self, value): + self.data = value + + @property + def _scale_inv(self): + return self.scale + + @_scale_inv.setter + def _scale_inv(self, value): + self.scale = value + + +class ExperimentalQuantizer(Quantizer): + """Experimental Quantizer class + + Defines the interface for experimental quantizers. 
+ """ + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.internal = True + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware""" + return True + + @abc.abstractmethod + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: MMParams, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: GEMMType = GEMMType.FPROP, + qresult_x: ExperimentalQuantizedTensor | None = None, + qresult_w: ExperimentalQuantizedTensor | None = None, + ) -> torch.Tensor: + """Quantized GEMM interface.""" + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def update_quantized(self, *args, **kwargs) -> torch.Tensor: + """Update the quantized tensor with the given tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized function" + ) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensorBase: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function" + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement calibrate function" + ) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe function" + ) diff --git a/transformer_engine/pytorch/experimental/quantization_microblock_ref.py 
b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py new file mode 100644 index 0000000000..da749d237f --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py @@ -0,0 +1,811 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental.quantization import ( + ExperimentalQuantizedTensor, + ExperimentalQuantizer, +) + + +def cast_to_fp4x2(x): + """Quantize a tensor to FP4 E2M1 and store in a byte tensor""" + + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def cast_from_fp4x2(x, dq_dtype): + """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor""" + fp4_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + device=x.device, + dtype=dq_dtype, + ) + + # Convert to long integers for indexing + second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long) + 
first_bit = (x - second_bit * 16).to(torch.long) + + # Use the long integers to index fp4_values + first_bit_values = fp4_values[first_bit] + second_bit_values = fp4_values[second_bit] + + result = torch.zeros( + (first_bit_values.shape[0], first_bit_values.shape[1] * 2), + device=x.device, + dtype=dq_dtype, + ) + result[:, ::2] = first_bit_values + result[:, 1::2] = second_bit_values + + return result + + +def cast_to_e8(decode_scale): + """Cast to a value that is representable in FP8 E8M0. + + The result is in FP32, not FP8 E8M0. + """ + max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32) + exponent = torch.ceil(torch.log2(decode_scale)) + exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent) + + return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent + + +def cast_to_e4m3(decode_scale, global_amax): + """Scale and cast to FP8 E4M3. + + decode_scale is actually the encoding scaling factor. global_amax + can be any data tensor and not just the amax. + + TODO(etsykunov): Make less unintuitive. 
+ """ + decode_scale = decode_scale * global_amax + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + return decode_scale.to(torch.float8_e4m3fn) + + +def high_precision_gemm_ref( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + accumulate: bool = False, + is_a_transposed: bool = False, + is_b_transposed: bool = False, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_alpha: float = 1.0, +) -> torch.Tensor: + """GEMM implementation with unquantized data""" + # Handle transpositions + mat1, mat2 = a, b + if is_a_transposed: + mat1 = a.T + if is_b_transposed: + mat2 = b.T + + # Ensure dtype compatibility for torch.addmm + mat1 = mat1.to(out_dtype) + mat2 = mat2.to(out_dtype) + + # Determine output shape + y_shape = (mat1.size(0), mat2.size(1)) + + if bias is not None: + assert not accumulate, "Bias is not supported with accumulation" + bias = bias.to(out_dtype) + # With bias case + if out_dtype == torch.float32: + y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1) + else: + y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha) + else: + # Without bias case + if accumulate and out is not None: + y_ref = out.clone().to(out_dtype) + else: + y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device) + torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref) + + return y_ref + + +class NVFP4TensorRef(ExperimentalQuantizedTensor): + """NVFP4 tensor for middleware between Transformer Engine and Kitchen""" + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"dtype={self.dtype}, " + f"device={self.device}, " + f"quant_dtype={self.quant_dtype}, " + f"data={self.dequantize(dtype=self.dtype)}, " + f"original_shape={self.original_shape}" + ")" + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: 
Optional[torch.Tensor] = None, + ) -> ExperimentalQuantizedTensor: + """In-place update of quantized data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, ExperimentalQuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from quantized tensor + """ + if dtype is None: + dtype = self.dtype + + # Ignore data_t for now + assert self.data is not None, "QuantizedTensor has no valid tensor data" + assert self.scale is not None, "QuantizedTensor has no valid scale" + tensor_data = self.data + tensor_scale = self.scale + # Dispatch to the quantizer + return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """Generate or remove quantized data based on provided usage.""" + has_data = self.data is not None + has_data_transpose = self.data_t is not None + needs_data = has_data + needs_data_transpose = has_data_transpose + + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self.data = None + if not needs_data_transpose: + self.data_t = None + + def _create_transpose(self): + """Create 
transposed quantized tensor""" + if not self.data.is_contiguous(): + self.data = self.data.contiguous() + self.data_t = self.data.t().contiguous() + self.scale_t = self.scale + + def size(self, *args, **kwargs): # pylint: disable=unused-argument + """Return the original tensor shape, not the internal packed data shape. + + FP4 quantization packs two 4-bit values into each 8-bit value, which reduces + the second dimension by half. This method returns the logical shape that + users expect, not the internal packed storage shape. + """ + assert self.original_shape is not None + return torch.Size(self.original_shape) + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0], + dtype=torch.float32, + ) + + +class NVFP4QuantizerRef(ExperimentalQuantizer): + """NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" + + def __init__( + self, + dtype: utils.Fp4Formats, + rowwise: bool = True, + columnwise: bool = True, + pow_2_scales: bool = False, + eps: float = 0.0, + quant_tile_shape: Tuple[int, int] = (1, 16), + with_rht: bool = False, + with_random_sign_mask: bool = True, + ): + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = dtype + self.pow_2_scales = pow_2_scales + self.eps = eps + self.quant_tile_shape = quant_tile_shape + self.with_rht = with_rht + self.with_random_sign_mask = with_random_sign_mask + + @staticmethod + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. 
+ """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: + """Apply randomized Hadamard transform without random signs (reference path). + + This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). + """ + # Only apply when enabled + if not self.with_rht: + return x + + # RHT dimension equals the quantization tile length (NVFP4 uses 16) + rht_dim = self.quant_tile_shape[1] + assert ( + x.shape[-1] % rht_dim == 0 + ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + + # Build H and scale + H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + @staticmethod + def _recover_swizzled_scales( + swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int + ) -> torch.Tensor: + if not swizzled_scale: + return scale + rounded_m = utils.roundup_div(m, 128) * 128 + scale_n = utils.roundup_div(n, block_length) + rounded_n = utils.roundup_div(scale_n, 4) * 4 + # Recover swizzled scaling factor layout -> linear layout + tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4)) + # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4] + tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + 
result = torch.reshape(tmp, (rounded_m, rounded_n)) + return result[:m, :scale_n] + + @classmethod + def _quantize_blockwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + tile_len_y: int, + *, + pow_2_scales: bool, + eps: float, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, torch.Tensor]: + + assert x.ndim == 2 + using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 + m, n = x.shape + # Compute vec_max based on the original x (before reshape) + # For 1D quantization: amax over each row chunk of 16 + # For 2D quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax + if using_2d_quantization: + # x shape: (128, 128) + x_blocks = ( + x.unfold(0, tile_len_y, tile_len_y) + .unfold(1, tile_len_x, tile_len_x) + .to(torch.float32) + ) # (8, 8, 16, 16) + block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8) + # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows + vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1) + else: + # x shape: (128, 128) + x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16) + vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to( + torch.float32 + ) # (128, 8, 1) + x = x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) + + if pow_2_scales: + decode_scale = cast_to_e8(decode_scale) + encode_scale = torch.div( + torch.tensor(1.0, device=x.device, dtype=torch.float32), + decode_scale.to(torch.float32), + ) + else: + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + 
dtype=torch.float32, + ), + ) + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + decode_scale = decode_scale * global_encode_scale + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + + scaled_x = x.to(torch.float32) * encode_scale + + clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + + @staticmethod + def _pad_tensor( + tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] + ) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = tensor.shape + padding_needed_rows = 0 + padding_needed_cols = 0 + + if row_divisor is not None and M % row_divisor != 0: + padding_needed_rows = row_divisor - (M % row_divisor) + # Check and calculate column padding if col_divisor is provided + if col_divisor is not None and N % col_divisor != 0: + padding_needed_cols = col_divisor - (N % col_divisor) + + # Return original tensor if no padding is needed + if padding_needed_rows == 0 and padding_needed_cols == 0: + return tensor + + # pad the tensor + out = torch.nn.functional.pad( + tensor, + (0, padding_needed_cols, 0, padding_needed_rows), + mode="constant", + value=0.0, + ).contiguous() + + return out + + @staticmethod + def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: + + 
assert tensor.dim() == 2, "only supports 2D tensors" + M, N = original_size + out = tensor[:M, :N].contiguous() + return out + + def _quantize(self, tensor: torch.Tensor) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ]: + """ + Python implementation of microblock FP4 quantization. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to quantize (should be 2D) + + Returns + ------- + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor] + (qx, sx, qx_t, sx_t, global_amax) where: + - qx: quantized data in row-major order (if rowwise_usage), None otherwise + - sx: scale tensor for qx (if rowwise_usage), None otherwise + - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise + - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise + - global_amax: global amax tensor + """ + if self.pow_2_scales: + assert self.quant_tile_shape == ( + 1, + 32, + ), "MXFP4 only supports 1x32 tile shape." + # TODO(etsykunov): Fix bug where global_amax_row and + # global_amax_col are not defined + # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) + else: + assert self.quant_tile_shape in ( + (1, 16), + (16, 16), + ), "NVFP4 only supports 1x16 or 16x16 tile shape." + # Prepare inputs once so we can reuse for both amax and quantization + # Row-input will always be the original input. 
+ row_input = tensor + col_input = ( + self._apply_rht(tensor.t().contiguous()) + if self.with_rht + else tensor.t().contiguous() + ) + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) + + transpose_scales = False + + M, N = tensor.shape + if self.rowwise_usage: + x_input = row_input + x_padded = self._pad_tensor( + x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx, sx = self._quantize_blockwise_reference( + x_padded, + global_amax_row, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + if transpose_scales: + sx = sx.T + + qx = self._rm_pad_tensor(qx, (M, N // 2)) + + else: + qx = None + sx = None + + if self.columnwise_usage: + x_t = col_input + x_t_padded = self._pad_tensor( + x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + + qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) + + if transpose_scales: + sx_t = sx_t.T + else: + qx_t = None + sx_t = None + + return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col + + def quantize( + self, + tensor: torch.Tensor, + **kwargs, # pylint: disable=unused-argument + ) -> NVFP4TensorRef: + # sanity checks + assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." 
+
+        # Make it work with 3D tensors
+        original_shape = tensor.shape
+        if tensor.ndim > 2:
+            tensor = tensor.view(-1, tensor.shape[-1])
+
+        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor)
+
+        return NVFP4TensorRef(
+            data=qx,
+            scale=sx,
+            data_t=qx_t,
+            scale_t=sx_t,
+            global_amax_row=global_amax_row,
+            global_amax_col=global_amax_col,
+            dtype=tensor.dtype,
+            device=tensor.device,
+            quant_dtype=self.dtype,
+            quantizer=self,
+            original_shape=original_shape,
+        )
+
+    def update_quantized(
+        self,
+        src: torch.Tensor,
+        dst: ExperimentalQuantizedTensor,
+        *,
+        noop_flag: Optional[torch.Tensor] = None,
+    ) -> ExperimentalQuantizedTensor:
+        """Update the quantized tensor with the given tensor in-place
+
+        Parameters
+        ----------
+        src: torch.Tensor
+            Source tensor to copy from
+        dst: ExperimentalQuantizedTensor
+            Destination ExperimentalQuantizedTensor to update
+        noop_flag: torch.Tensor, optional
+            float32 flag indicating whether to avoid performing update
+        """
+        # Handle noop flag
+        if noop_flag is not None and noop_flag.item() != 0:
+            return dst
+
+        # Make sure input is in expected format
+        if not src.is_contiguous():
+            src = src.contiguous()
+
+        # Store the original shape and reshape for processing
+        original_shape = src.shape
+        if src.ndim > 2:
+            src = src.view(-1, src.shape[-1])
+
+        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
+
+        # Update the destination with new data
+        dst.data = qx
+        dst.scale = sx
+        dst.data_t = qx_t
+        dst.scale_t = sx_t
+        dst.global_amax_row, dst.global_amax_col = global_amax_row, global_amax_col
+        dst.dtype = src.dtype
+        dst.quant_dtype = self.dtype
+        dst.original_shape = original_shape
+
+        return dst
+
+    @property
+    def supports_allgather_fp8(self) -> bool:
+        """Whether the tensor data can be all-gathered with an FP8 all-gather.
+
+        TODO(etsykunov): Confirm docstring is correct. Also, this API
+        seems too FP8-specific and should be reconsidered.
+        """
+        return False
+
+    def transpose_qresult(
+        self, qresult: quantization.ExperimentalQuantizedTensor
+    ) -> quantization.ExperimentalQuantizedTensor:
+        """Convert row-wise data to column-wise data (?)
+
+        TODO(etsykunov): Confirm docstring is correct.
+        """
+        raise NotImplementedError("Transpose qresult is not implemented for FP4.")
+
+    @property
+    def supports_dequantize(self) -> bool:
+        """Whether quantized tensor can be converted to high-precision tensor"""
+        return False
+
+    @property
+    def is_data_t_transposed_in_memory(self) -> bool:
+        """Whether column-wise data is stored in transposed layout.
+
+        TODO(etsykunov): Confirm docstring is correct.
+        """
+        raise NotImplementedError("Not implemented yet")
+
+    def dequantize(
+        self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
+    ) -> torch.Tensor:
+        """Dequantize the quantized tensor"""
+        raise NotImplementedError("Not implemented yet")
+
+    def qgemm(
+        self,
+        qx: torch.Tensor,
+        qw: torch.Tensor,
+        m_params: quantization.MMParams,
+        out_dtype: torch.dtype,
+        sx: torch.Tensor,
+        sw: torch.Tensor,
+        bias: torch.Tensor | None = None,
+        out: torch.Tensor | None = None,
+        accumulate: bool = False,
+        gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
+        qresult_x: quantization.ExperimentalQuantizedTensor | None = None,
+        qresult_w: quantization.ExperimentalQuantizedTensor | None = None,
+    ) -> torch.Tensor:
+        assert bias is None, "Bias is not implemented for FP4 GEMM."
+ + high_precision_x = cast_from_fp4x2(qx, out_dtype) + high_precision_w = cast_from_fp4x2(qw, out_dtype) + + if self.pow_2_scales: + + if sx.dtype == torch.uint8: + # if scaling factor is stored in uint8 container + sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** ( + ( + sx.to(torch.float32) + - torch.tensor(127, device=sx.device, dtype=torch.float32) + ) + ) + sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** ( + ( + sw.to(torch.float32) + - torch.tensor(127, device=sw.device, dtype=torch.float32) + ) + ) + else: + # if scaling factor is torch.float8_e8m0fnu + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32) + + else: + + assert qresult_x is not None + assert qresult_w is not None + + assert qresult_x.global_amax_row is not None + assert qresult_w.global_amax_col is not None + + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + factor = 6.0 * 6.0 * 448.0 * 448.0 + + if gemm_type == quantization.GEMMType.WGRAD: + partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col + else: + partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row + alpha = torch.div(partial_alpha, factor).squeeze(-1) + + M, K = high_precision_x.shape + N, K_w = high_precision_w.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + assert K % 32 == 0, "K dimension must be divisible by 32" + assert N % 8 == 0, "N dimension must be divisible by 8" + + block_length = 32 if self.pow_2_scales else 16 + + grid_k = K // block_length + + assert sx.shape == ( + M, + K // block_length, + ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" + assert sw.shape == ( + N, + K // block_length, + ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # below implementation is to match the FP4 tensor core implementation + # Each output 
element (i, j) is fp32 accumulation of (K // block_length) inner products + # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision in fp32 + # Then batch the computation in M, N dimension + for k in range(grid_k): + k_start = k * block_length + k_end = k_start + block_length + + qx_block = high_precision_x[:, k_start:k_end].clone().contiguous() + qw_block = high_precision_w[:, k_start:k_end].clone().contiguous() + + # Extract scaling factors for the current blocks + sx_block = sx[:, k] + sw_block = sw[:, k] + + y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref( + qx_block, qw_block, torch.float32, is_b_transposed=True + ) + + if not self.pow_2_scales and K > 0: + # only apply global scale for NVFP4 and non-empty cases + y = alpha * y + + # accumulation happens at epilogue in float32 + if accumulate: + assert out is not None, "Output tensor must be provided for accumulation." + y += out.to(torch.float32) + else: + assert out is None, "Output tensor should be None when accumulate is False." + + y = y.to(out_dtype) + return y diff --git a/transformer_engine/pytorch/experimental/utils.py b/transformer_engine/pytorch/experimental/utils.py new file mode 100644 index 0000000000..20dc6f11b0 --- /dev/null +++ b/transformer_engine/pytorch/experimental/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Utility functions for experimental middleware between Transformer Engine and Kitchen.""" + +import enum + +import torch + + +HIGH_PRECISION_FLOAT_DTYPES = ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, +) + + +class Fp4Formats(enum.Enum): + """FP4 data format""" + + E2M1 = "e2m1" + + +def roundup_div(x: int, y: int) -> int: + """Round up division""" + assert x >= 0 + assert y > 0 + return (x + y - 1) // y diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8f9dbd88d0..a75a03bfa5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -21,6 +21,7 @@ MXFP8BlockScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, ) from .constants import dist_group_type @@ -53,6 +54,13 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if nvfp4 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." 
+ + def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if ( @@ -105,6 +113,13 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + if fp4_recipe.fp4_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -142,6 +157,8 @@ class FP8GlobalStateManager: reason_for_no_mxfp8 = "" fp8_block_scaling_available = None reason_for_no_fp8_block_scaling = None + nvfp4_available = None + reason_for_no_nvfp4 = "" @classmethod def reset(cls) -> None: @@ -205,6 +222,13 @@ def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: ) return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @classmethod + def is_nvfp4_available(cls) -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + if cls.nvfp4_available is None: + cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support() + return cls.nvfp4_available, cls.reason_for_no_nvfp4 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -481,6 +505,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, Float8BlockScaling): fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() assert fp8_block_available, reason_for_no_fp8_block + if isinstance(fp8_recipe, NVFP4BlockScaling): + nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() + assert nvfp4_available, reason_for_no_nvfp4 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -837,6 +864,8 @@ def create( cls = Float8CurrentScalingRecipeState elif 
recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( @@ -1084,3 +1113,79 @@ def make_quantizers(self) -> list: ] ) ) + + +class NVFP4BlockScalingRecipeState(RecipeState): + """Configuration for NVFP4 quantization. + + NVFP4 quantization does not require state. + + """ + + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + from .tensor.nvfp4_tensor import NVFP4Quantizer + + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward. It assumes forward quantizers are + # ordered [input, weight, output, ...] and backward quantizers + # are ordered [grad_output, grad_input, ...]. This doesn't + # play nicely with fusible ops: Linear op doesn't own output + # or grad input quantizers, Quantize op only owns input and + # grad output quantizers. 
+ + if self.mode == "forward": + + def _make_quantizer(idx: int) -> NVFP4Quantizer: + qparams = ( + self.recipe.fp4_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp4_quant_fwd_inp + ) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, + stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + ) + for _ in range(self.num_quantizers) + ] + + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index e4fa0c7411..3505a68307 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,16 +4,18 @@ """Internal function used by multiple modules.""" -from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass - +import dataclasses import queue +from typing import Any, Callable, List, Optional, Tuple, Union + import torch from .. import cpp_extensions as tex +from .. 
import experimental from ..constants import TE_DType -from ..utils import get_default_init_method from ..export import is_in_onnx_export_mode +from ..tensor.utils import is_experimental +from ..utils import get_default_init_method def _get_normalization_func(normalization: str, forward: bool): @@ -170,7 +172,33 @@ def noop_cat( return _NoopCatFunc.apply(dim, *tensors) -@dataclass +def get_module_quantizers( + module: torch.nn.Module, + fp8_output: bool, + fp8_grad: bool, + debug: bool, +): + """Return the 6-tuple of quantizers for a module in a centralized way. + + Routing policy: + - If experimental quantization is enabled via environment and module.fp8 is True, + return experimental quantizers. + - Otherwise, return the module's own quantizers (debug or regular). + """ + if getattr(module, "fp8", False) and is_experimental(): + # TODO(etsykunov): Quantizer instantiation should be better + # done in the module's constructor + qlinear_params = experimental.config.set_qlinear_params() + + if qlinear_params is not None: + return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params) + + if not debug: + return module._get_quantizers(fp8_output, fp8_grad) + return module._get_debug_quantizers(fp8_output, fp8_grad) + + +@dataclasses.dataclass class _ParameterInitMeta: """ Stores essential metadata needed to support deferred parameter initialization. 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 70366dabe5..bf4fb97d2d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -27,6 +27,7 @@ DelayedScalingRecipeState, Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -39,6 +40,7 @@ from ..constants import dist_group_type from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase @@ -76,7 +78,8 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - return 33_554_432 + # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales + return 32 * 1024 * 1024 + 256 return 4_194_304 @@ -757,6 +760,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8BlockScalingRecipeState ): return + if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -1218,15 +1223,13 @@ def grad_output_preprocess( ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - if isinstance(quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. 
+ if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance( - grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), - ): + if not isinstance(grad_output, QuantizedTensorBase): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4d30be414e..6dbbd335eb 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -29,6 +30,7 @@ from ..fp8 import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, + assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -53,7 +55,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore +from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers from ..tensor.quantized_tensor import ( QuantizedTensor, QuantizedTensorBase, @@ -135,6 +137,8 @@ def forward( if ub_name is not None: nvtx_label = f"{nvtx_label}.{ub_name}" + with_input_all_gather = parallel_mode == "column" and sequence_parallel + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -144,6 +148,7 @@ def forward( inputmat = inp if 
fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") @@ -157,7 +162,6 @@ def forward( weight_requires_grad = weight.requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad - with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) if debug: # turn off userbuffers in debug mode @@ -190,11 +194,13 @@ def forward( # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. + experimental = is_experimental(input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # Apply normalization @@ -240,7 +246,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather @@ -1422,6 +1429,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: @@ -1526,11 +1535,7 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) + quantizers = get_module_quantizers(self, fp8_output, 
fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1763,6 +1768,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9f799c5538..a0e5f3aedd 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -17,6 +17,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -40,6 +41,7 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, clear_tensor_data, 
requires_grad, needs_quantized_gemm, @@ -64,6 +66,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload @@ -114,7 +117,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] - if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + # TODO(ksivaman): Fuse nvfp4 act once kernel is available. + if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4(): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -211,6 +215,7 @@ def forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -258,11 +263,13 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm + experimental = is_experimental(fc1_input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # Apply normalization @@ -302,7 +309,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -548,6 +556,7 @@ def forward( if not fc2_weight.requires_grad: 
clear_tensor_data(act_out) act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, @@ -673,6 +682,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None @@ -1014,7 +1024,10 @@ def fc2_wgrad_gemm( if ctx.fp8: # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance( + ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) else: @@ -1690,6 +1703,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) 
def reset_layer_norm_parameters(self) -> None: @@ -1908,7 +1923,10 @@ def _get_quantizers(self, fp8_output): fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, - columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), + columnwise=isinstance( + fc2_input_quantizer, + (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), + ), ) fc1_input_quantizer.internal = True if fp8_output: @@ -2113,6 +2131,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.set_parallel_mode: + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7e526245c1..cf7f58947b 100644 --- 
a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, WeightGradStore +from ._common import noop_cat, WeightGradStore, get_module_quantizers from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -35,6 +35,7 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, ) @@ -65,6 +66,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.utils import is_experimental from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -151,6 +153,9 @@ def forward( ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG + # experimental recipe check + experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) + # ------------------------------------------------------ # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -161,6 +166,7 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer @@ -172,7 +178,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase): + if not isinstance(inputmat, QuantizedTensorBase) and not experimental: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -442,6 +448,7 @@ def forward( ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug + 
ctx.experimental = experimental ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -609,7 +616,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat, QuantizedTensorBase): # Input tensor is already quantized pass - elif ctx.debug: + elif ctx.debug or ctx.experimental: # Debug quantizer will be applied immediately before wgrad GEMM pass else: @@ -698,6 +705,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weight_fp8, @@ -1326,6 +1334,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_parameters(self, defer_init=False): @@ -1410,12 +1420,7 @@ def forward( weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) - + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1655,6 +1660,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert recipe.nvfp4(), "Incorrect recipe." 
+ if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 70c70c54d2..f8f95cf194 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -926,6 +926,7 @@ def op_forward( input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=False) + # Recipe-specific configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale @@ -940,6 +941,13 @@ def op_forward( if self.sequence_parallel and self.tensor_parallel_mode == "row": grad_output_quantizer.with_amax_reduction = True grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + 
grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7fa12cc087..43846512d7 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -54,6 +54,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, ) + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase all_tensor_types = [ torch.Tensor, @@ -64,5 +65,7 @@ def get_all_tensor_types(): MXFP8TensorBase, Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, + NVFP4Tensor, + NVFP4TensorBase, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py new file mode 100644 index 0000000000..df187d6741 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py @@ -0,0 +1,348 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Mixin class holding data specific for NVFP4Tensor""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import math +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch + +# import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ..quantized_tensor import QuantizedTensorBase + +# from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ..quantized_tensor import Quantizer +from ...utils import _empty_tensor + + +@functools.lru_cache(maxsize=None) +def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Values representable in FP4 E2M1 format""" + return torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + device=device, + dtype=dtype, + ) + + +class _FromNVFP4Func(torch.autograd.Function): + """Cast from NVFP4 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: NVFP4TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + + # Dequantize row-wise data + if tensor._rowwise_data is not None: + ### TODO(tmoon): Debug dequantize kernel and remove unfused impl + # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) + + # Tensor properties + shape = list(tensor._rowwise_data.size()) + shape[-1] *= 2 + device = tensor._rowwise_data.device + + # Convert FP4E2M1 values to FP32 + data = tensor._rowwise_data.view(torch.uint8).to(torch.int32) + data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape) + data = _fp4_e2m1_vals(device, dtype=torch.float32)[data] + data = data.to(torch.float32).contiguous() + + # Convert FP8E4M3 block scales to FP32 + block_scales = tensor._rowwise_scale_inv + block_scales = block_scales.reshape(-1, block_scales.size(-1)) + block_scales = block_scales[: 
math.prod(shape[:-1]), : shape[-1] // 16] + block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32) + + # Convert amax to FP32 tensor scale + tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max + + # Apply scales + block_data = data.view(-1, 16) + block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1) + + return data.to(dtype) + + if tensor._columnwise_data is not None: + raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") + raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class NVFP4TensorBase(QuantizedTensorBase): + """Mixin class that holds data attributes of NVFP4Tensor. + + NVFP4Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. 
+ + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + _fp4_dtype: TE_DType + _amax_rowwise: torch.Tensor + _amax_columnwise: torch.Tensor + + def __new__( + cls, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + quantizer: Optional[Quantizer], + *args, + **kwargs, + ): + + instance = super().__new__(cls, *args, **kwargs) + + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._fp4_dtype = fp4_dtype + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._amax_rowwise = amax_rowwise + instance._amax_columnwise = amax_columnwise + + return instance + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ): + if t is not None: + t.data = _empty_tensor() + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "amax_rowwise": self._amax_rowwise, + "amax_columnwise": self._amax_columnwise, + "fp4_dtype": self._fp4_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale_inv = None + self._columnwise_scale_inv = None + self._amax_rowwise = None + self._amax_columnwise = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale_inv = tensors[2] + self._columnwise_scale_inv = tensors[3] + self._amax_rowwise = tensors[4] + self._amax_columnwise = tensors[5] + return tensors[6:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromNVFP4Func.forward(None, self, dtype) + + def size(self, dim: Optional[int] = None) -> Union[torch.Size, 
int]: + # pylint: disable=missing-function-docstring + + # Infer tensor shape + shape = None + if self._rowwise_data is not None: + byte_shape = list(self._rowwise_data.size()) + shape = byte_shape[:-1] + [byte_shape[-1] * 2] + elif self._columnwise_data is not None: + warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.") + byte_shape = list(self._columnwise_data.size()) + shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]] + if shape is None: + raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data") + + # Return shape or dim + if dim is None: + return torch.Size(shape) + return shape[dim] + + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if self._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = self._rowwise_data.view(byte_shape) + if self._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = self._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4TensorBase( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + amax_rowwise=self._amax_rowwise, + amax_columnwise=self._amax_columnwise, + quantizer=self._quantizer, + fp4_dtype=self._fp4_dtype, + ) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorBase(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + For the NVFP4 format, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. 
+ """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses" + ) + if self._amax_rowwise is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing per tensor" + " row-scaled scale-inverse" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + self._amax_rowwise = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing column-scaled scale-inverses" + ) + if self._amax_columnwise is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing per tensor column-scaled scale-inverse" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + self._amax_columnwise = None diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 321c351dd0..d7f5f8c7d2 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
-"""Tensor class with FP8 data""" +"""Tensor class with MXFP8 data""" from __future__ import annotations from collections.abc import Iterable import math @@ -186,8 +186,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. + precision. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py new file mode 100644 index 0000000000..b12e89956a --- /dev/null +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -0,0 +1,898 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with NVFP4 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple, Union +import functools + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..utils import ( + canonicalize_process_group, + devices_match, + round_up_to_nearest_multiple, +) + +from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +def get_no_random_sign_vector() -> torch.Tensor: + """Non-random sign vector for Hadamard transform.""" + return torch.tensor([1], dtype=torch.float32) + + +def get_sign_from_vector(vector: torch.Tensor) -> int: + """Convert sign vector to bitmask. + + Used for random Hadamard transform. 
+ + """ + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded random signs for Hadamard transform. + + https://xkcd.com/221/ + + """ + return torch.tensor( + [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], + dtype=torch.float32, + ) + + +def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor: + """Construct a 16x16 Hadamard matrix.""" + assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported." + hadamard_scale = 1 / math.sqrt(hadamard_dimension) + return ( + torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1], + [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1], + [1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1], + [1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1], + [1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1], + [1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1], + [1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1], + [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1], + [1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1], + [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1], + [1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1], + [1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1], + [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1], + [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1], + ], + dtype=torch.float32, + ) + * hadamard_scale + ) + + +@functools.lru_cache(maxsize=None) +def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: + """Construct matrix used in random Hadamard transform.""" + hadamard_dimension = 16 + if with_random_sign_mask: + signs = get_wgrad_sign_vector() + else: + signs = get_no_random_sign_vector() + sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32) + 
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension) + return rht_matrix.to(dtype=torch.bfloat16).cuda() + + +@functools.lru_cache(maxsize=None) +def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int: + """Sign mask for random Hadamard transform.""" + if with_random_sign_mask: + return get_sign_from_vector(get_wgrad_sign_vector()) + return 0 + + +class NVFP4Quantizer(Quantizer): + """Builder class for NVFP4 tensors with NV block scaling""" + + dtype: TE_DType + """Random Hadamard Transform""" + with_rht: bool + with_post_rht_amax: bool + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + + """2D block scaling, only applicable for weights.""" + with_2d_quantization: bool + + """Stochastic rounding, only applicable for gradients.""" + stochastic_rounding: bool + + """RHT matrix random sign mask""" + rht_matrix_random_sign_mask_t: int + rht_matrix: torch.Tensor + + def __init__( + self, + fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + with_rht: bool = False, + with_post_rht_amax: bool = False, + with_2d_quantization: bool = False, + stochastic_rounding: bool = False, + with_random_sign_mask: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp4_dtype + self.with_rht = with_rht + self.with_post_rht_amax = with_post_rht_amax + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.with_2d_quantization = with_2d_quantization + self.stochastic_rounding = stochastic_rounding + self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask) + self.rht_matrix = get_rht_matrix(with_random_sign_mask) + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, 
+ ) -> QuantizedTensor: + + assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + return True + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For NVFP4 1D blockwise quantization, blocksize is 16 + - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4)) + - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4)) + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. 
+ CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = round_up_to_nearest_multiple(K, 128) + inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + # rowwise + outer = round_up_to_nearest_multiple(M, 128) + inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + + @staticmethod + def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of a tensor after columnwise quantization. + + For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling. + + Parameters + ---------- + shape : Iterable[int] + Original shape of the tensor + + Returns + ------- + Tuple[int, ...] + New shape with dimensions rearranged for columnwise layout. + For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). + Returns empty tuple for empty input shape. + """ + if len(shape) == 0: + return tuple() + # and then after AG, a reorganize kernel will be called to restore the shape + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + @staticmethod + def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]: + """Convert shape for FP4 data by dividing the last dimension by 2""" + shape = list(shape) + shape[-1] = shape[-1] // 2 + return tuple(shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> NVFP4Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. 
Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + flat_first_dim = math.prod(shape[:-1]) + assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP4 data + data = None + scale_inv = None + amax_rowwise = None + if self.rowwise_usage: + data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device) + # Allocate per tensor scale inverse. FP32 format. + amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + amax_columnwise = None + if self.columnwise_usage: + # enforce 2D shape to avoid [S, B, H] shape and B and be 1 + # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + shape_2d = tuple([flat_first_dim, shape[-1]]) + columnwise_data = torch.empty( + self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), + dtype=torch.uint8, + device=device, + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, dtype=torch.uint8, device=device + ) + amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Construct FP8 tensor + return NVFP4Tensor( + shape=shape, + dtype=dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.dtype, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + pass # Calibration is no-op + + def _canonicalized_amax_reduction_group(self) -> dist_group_type: + """Get process group for amax 
reduction""" + return canonicalize_process_group(self.amax_reduction_group) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return NVFP4BlockScaling + + +class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): + """Quantized tensor class with FP4 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP4. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + Raw FP4 data in a uint8 tensor (rowwise layout). + rowwise_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP4, i.e. the scaling factor that must + be applied when casting from FP4 to higher + precision (rowwise). + columnwise_data: torch.Tensor, optional + Raw FP4 data in a uint8 tensor (columnwise layout). + columnwise_scale_inv: torch.Tensor, optional + Reciprocal of the scaling factor for columnwise FP4 data. + amax_rowwise: torch.Tensor, optional + Rowwise amax tracking tensor. + amax_columnwise: torch.Tensor, optional + Columnwise amax tracking tensor. + fp4_dtype: TE_DType + The FP4 data type used for quantization. + quantizer: Quantizer + The quantizer instance used for this tensor. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype, used in dequantize. + """ + + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. 
+ def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + fp4_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + quantizer, + *args, + **kwargs, + ) + return instance + + def __repr__(self, *, tensor_contents=None): + return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from NVFP4Tensor + + By default the resulting tensor's dtype is the + NVFP4Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromNVFP4Func.apply(self, dtype) + return _FromNVFP4Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. 
+ + """ + if self._quantizer is not None: + return self._quantizer + return NVFP4Quantizer() + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> NVFP4Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return NVFP4Tensor.make_like(self) + + def clone(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> NVFP4Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. 
+ + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("NVFP4Tensor does not support different memory formats!") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + if len(args) != 2: + raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})") + tensor = args[0] + shape = args[1] + if shape == list(tensor.size()): + return tensor.detach() + return tensor.view(shape) + + # NVFP4 dequantize not supported. Add manual support for needed funcs. + if func in (aten.empty_like.default, aten.zero_.default): + tensor = args[0] + data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like + scale_inv_init_func = ( + torch.ones_like if func == aten.zero_.default else torch.empty_like + ) + + if tensor._rowwise_data is not None: + rowwise_data = data_init_func(tensor._rowwise_data) + rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv) + amax_rowwise = torch.zeros_like(tensor._amax_rowwise) + else: + rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None + + if tensor._columnwise_data is not None: + columnwise_data = data_init_func(tensor._columnwise_data) + columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv) + amax_columnwise = torch.zeros_like(tensor._amax_columnwise) + else: + columnwise_data, columnwise_scale_inv, amax_columnwise = ( + None, + None, + None, + ) + + return NVFP4Tensor( + shape=tensor.shape, + dtype=tensor.dtype, + fp4_dtype=tensor._fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + 
quantizer=tensor._quantizer, + requires_grad=tensor.requires_grad, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + NVFP4Tensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._fp4_dtype, + self.dtype, + self._quantizer, + ), + ) + + def _get_data(self) -> NVFP4Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a NVFP4Tensor. Otherwise + casts to FP8. 
+ + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is NVFP4Tensor + if isinstance(tensor, NVFP4Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + NVFP4Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + self._amax_rowwise = tensor._amax_rowwise + self._amax_columnwise = tensor._amax_columnwise + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.update_quantized(tensor, self) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting NVFP4Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the NVFP4Tensor using the provided shape. 
+ + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.view(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.view(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.view(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the NVFP4Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.reshape(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.reshape(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.reshape(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.reshape(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 656eda46ca..7b88d25196 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -264,6 +264,10 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False + def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + """Returns whether or not given tensor can be quantized""" + return True + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 23f56da5d0..a4bdf5e07d 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,11 +4,13 @@ """Helper functions for using fp8 tensors as weights""" +import os +from typing import Optional, Union import torch import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from 
.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -450,3 +452,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling( tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) + + +def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: + """Check if an environment or object is using experimental Kitchen middleware. + + Returns False if x is a torch.Tensor. + """ + # Detect if the environment is experimental + if x is None: + return int(os.getenv("QAT_PARAMS", "0")) > 0 + + # Detect if the object is experimental + if isinstance(x, torch.Tensor): + return False + if not isinstance(x, (Quantizer, QuantizedTensorBase)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") + return hasattr(x, "experimental") and x.experimental diff --git a/transformer_engine/pytorch/triton/pad.py b/transformer_engine/pytorch/triton/pad.py new file mode 100644 index 0000000000..29b0daf310 --- /dev/null +++ b/transformer_engine/pytorch/triton/pad.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""NVFP4 padding kernels + +TODO(ksivamani): Documentation + +""" + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), + ], + key=["out_dim0", "out_dim1"], +) +@triton.jit +def zero_pad_kernel( + inp_ptr, + out_ptr, + in_dim0: tl.constexpr, + in_dim1: tl.constexpr, + out_dim0: tl.constexpr, + out_dim1: tl.constexpr, + in_s0, + in_s1, + out_s0, + out_s1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + # tile over OUTPUT coordinates + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols + om = offs_m[:, None] + on = offs_n[None, :] + + # edge masking for output + out_mask = (om < out_dim0) & (on < out_dim1) + + # valid input region is simply top-left (no offsets) + in_mask = (om < in_dim0) & (on < in_dim1) + + # load valid input, else zero (masked load touches memory only where True) + x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) + + # store to output (only within bounds of the output tile) + tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) + + +def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor: + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + assert inp.ndim == 2 + dim0, dim1 = inp.shape + + pad_x = (128 - dim0 % 128) % 128 + pad_y = (4 - dim1 % 4) % 4 + out_x = dim0 + pad_x + out_y = dim1 + pad_y + out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype) + + in_s0, in_s1 = inp.stride() + out_s0, out_s1 = out.stride() + + 
BLOCK_M, BLOCK_N = 128, 128 + grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N)) + + zero_pad_kernel[grid]( + inp, + out, + dim0, + dim1, + out_x, + out_y, + in_s0, + in_s1, + out_s0, + out_s1, + ) + return out diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 6420f3e120..1a0722f894 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,8 +11,8 @@ import numpy as np import torch -import transformer_engine.pytorch.cpp_extensions as ext from . import torch_version +from .tensor.quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -441,6 +441,16 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: ) +def assert_dim_for_all_gather( + tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + + def is_bf16_compatible() -> None: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher. 
@@ -460,6 +470,8 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" + import transformer_engine.pytorch.cpp_extensions as ext + encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From 2354fb8b02ec64e73f82f2fd564f541c29a5e737 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 30 Sep 2025 09:01:24 -0700 Subject: [PATCH 43/78] Fix the segfault in the nvfp4 quantization (#2214) * Fix the segfault in the nvfp4 quantization Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/util/nvfp4_transpose.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh index fe9736298d..712b557c5d 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -1433,7 +1433,8 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o const size_t block_size = THREADS_NUM; const size_t scale_stride = output->scale_inv.shape[1]; - const size_t scale_stride_transpose = output->columnwise_scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? 
output->columnwise_scale_inv.shape[1] : 0; nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); nvfp4_scale_t *const scales_transpose_ptr = From 25252e9f2bc1460a841f32ef172126fb9192515a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:17:37 -0700 Subject: [PATCH 44/78] [PyTorch] Add FP8 attention with current scaling (#2012) * debug existing usage Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_dpa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reimplement fp8_dpa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * redesign CS; need cleanup Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up s/dP quantizers Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return dP to DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * improve quantizer_helper; tweak dP DS/CS logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * debug CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE commit Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> * clean up non-CP; debug dq/dk mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor success with CP; need to remove debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable fp8 output for fp8_mha + CS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add output_tensor_type to FADescriptor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove print Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more fixes for non-CP and CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable non-determinism for blackwell Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix indent; remove print Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch from create_tensor_from_data to make_like Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable a2a+p2p for CS CP and require additional cp_group_global Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * condense tests; only 
create dist groups once Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * consolidate CP P2P per-tile calls for fwd/bwd and fused/flash Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix flash-attn from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes for previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix attn_mask_type in f16 causal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert bb6a0a59 temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reenable comparison for some tensors in CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dbias for fused attn CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints/comments and add back NVTE_CS_dP_SCALE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * first attempt at mixed DS/CS reduction Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for last commit for mixed DS/CS reduction Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints from 69639024 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix DS recipe for dP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_DPA_FORCE_DS to force DS for all DPA tensors, not just dP Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix NVTE_DPA_FORCE_DS and add NVTE_PRINT Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * modify DS recipe for MLPerf Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce only over TP group; need to think about CP group later Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * streamline fake_recipe/quantizer generation; allow NVTE_DPA_Fixed_Scales or DS-update S/dP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more print: NVTE_LAYER_NUMBER Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * split S/dP in env vars: NVTE_DPA_Fix_S_Scale and NVTE_DPA_Fix_dP_Scale Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix autocast_key for DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_REPEAT_in_F16 to repeat FP8 fwd/bwd passes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 CS to UnfusedDPA; unsuccessful; does not affect other backends Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporary: print min/max and save tensors for debugging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * emulate q/dq+bf16 with NVTE_Emulate_in_F16; add 
NVTE_DPA_FORCE_MXFP8 for MXFP8 q/dq Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add RHT to BMM1 with NVTE_RHT_BMM1 for the size Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * re-enable fused attn in dpa_fp8_vs_f16 test; changed during unfused attn implementation Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_FP8_CS_POWER_OF_2, NVTE_DPA_FORCE_BLOCKFP8, NVTE_Emulate_QDQ_QKV, NVTE_Emulate_QDQ_O, NVTE_Emulate_QDQ_dO Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add F16 O support for FP8 kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to TE FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * return to FE develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tidy up; untested Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes and improvements for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more minor fixes and improvements Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more small fixes/improvements; mostly for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com 
hooks for more information, see https://pre-commit.ci * fix CS/DS recipe switch in DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * avoid quantizing/saving of O when CS bwd uses F16 O Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move fp8_autocast(fp8_recipe) print to utils.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add debug logging to unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back prints of quantizers/layer_number for debugging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable amax reduction for both CS and DS tensors Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix NVTE_FP8_DPA_BWD=0 for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit for F16 fwd/bwd a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * small fixes for float8_current_scaling(), nominal types, and unruly d_out types Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_output in MHA and some CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes to CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for CP A2A Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clamp input data in tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove rmse and tighten 
atol/rtol for tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restructure fp8_recipes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "remove rmse and tighten atol/rtol for tests" This reverts commit 15dba6a59a5323d414f02cf22f099cb00d880532. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more fixes for linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 recipe changes for F16 code path Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to FE on main to help with merges Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch back to FE develop after merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE develop commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to GitHub FE 1.14.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to its latest main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for A2A Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit for A2A DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove memset for BSHD/SBHD FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * remove concat for qkv quantization in CS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * improve/simplify the logic for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add nominal_type for UnfusedDPA FP8 EmuFunc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: update env vars for DPA recipes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo in last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix DS recipe creation for NVFP4 global recipe Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace python max with torch.maximum Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CP A2A for FA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce prints in print_quantizers Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 env vars to NVTE_DEBUG prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add reduce_amax to DS repr Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * separate fp8_dpa/fp8_mha in CP tests; fix A2A for them; add f16_O tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address some reciews Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make data optional in create_hp_tensor_with_amax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for comments in bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * print cudnn version in attn tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable CS for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * alternative tests to reduce CI time Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make NVTE_DPA_FP8CS_O_in_F16 default to 1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _fp8 variables to avoid confusion Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return to requiring two cp_groups for a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace NVTE_PRINT with NVTE_DEBUG/_LEVEL for quantizer prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * provide a basic set of tests for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the last merge with nvfp4 PR Signed-off-by: 
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 backend selection for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce CP CI to essential tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix to CP test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix recipe logic in tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to concat for qkv quantization Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove cudnn version in qa scripts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../attention/run_attention_with_cp.py | 151 +- tests/pytorch/attention/test_attention.py | 141 +- .../attention/test_attention_with_cp.py | 91 +- .../fused_attn_f16_arbitrary_seqlen.cu | 8 +- .../common/fused_attn/fused_attn_fp8.cu | 147 +- transformer_engine/common/fused_attn/utils.h | 13 +- transformer_engine/common/recipe/__init__.py | 11 +- .../dot_product_attention/backends.py | 498 ++- .../dot_product_attention/context_parallel.py | 3283 ++++++++--------- .../dot_product_attention.py | 319 +- .../attention/dot_product_attention/utils.py | 221 +- .../pytorch/attention/multi_head_attention.py | 46 +- .../pytorch/cpp_extensions/fused_attn.py | 3 - transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 5 + 
.../pytorch/csrc/extensions/attention.cpp | 195 +- .../pytorch/csrc/extensions/cast.cpp | 20 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- transformer_engine/pytorch/fp8.py | 4 +- .../pytorch/tensor/float8_tensor.py | 4 + 21 files changed, 2970 insertions(+), 2202 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 1a7b4b78db..80a8e4af4d 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 +Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 7e47e7df8d..d490c235bb 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -12,14 +12,18 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert - dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -151,7 +155,7 @@ def get_tols(config, dtype): elif dtype == "fp8": atol = 5e-1 rtol = 5e-1 - rmse_tol = 0.1 + rmse_tol = 0.15 else: assert False, f"{dtype=} is not supported!" 
@@ -164,14 +168,23 @@ def run_dpa_with_cp( qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p", - fp8_mha=False, + fp8_bwd="True", + fp8_dpa="False", + fp8_mha="False", + scaling_mode="delayed", + f16_O="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # set up environment variables and config - fp8_mha = fp8_mha == "True" + fp8_bwd = fp8_bwd == "True" and dtype == "fp8" + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" + fp8_dpa = fp8_dpa == "True" and dtype == "fp8" + fp8_mha = fp8_mha == "True" and dtype == "fp8" + f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": @@ -219,8 +232,12 @@ def run_dpa_with_cp( sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) + if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -247,19 +264,38 @@ def run_dpa_with_cp( cu_seqlens_q_padded, cu_seqlens_kv_padded, ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() - for x in [q, k, v]: - x.requires_grad = True - - dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - if fp8_mha: + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = 
torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() + if scaling_mode == "delayed": + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) dout_quantizer = Float8Quantizer( fp8_dtype=tex.DType.kFloat8E5M2, scale=torch.tensor([1], dtype=torch.float32).cuda(), amax=torch.tensor([0], dtype=torch.float32).cuda(), ) + if scaling_mode == "current": + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) @@ -274,6 +310,7 @@ def run_dpa_with_cp( else: fp8_context = nullcontext() with fp8_context: + # q, k, v, out in FP8; dout in F16 out = core_attn( q, k, @@ -284,8 +321,9 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -298,24 +336,10 @@ def run_dpa_with_cp( ############ run with CP ############ logging.info(f"[Rank {rank}] Run with context parallelism") - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, 
- cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) - else: - fp8_context = nullcontext() - # set up inputs q_, k_, v_, dout_, *rest = [ - x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) ] bias_ = rest[0] if len(rest) else None if qkv_format == "bshd" or qkv_format == "sbhd": @@ -343,6 +367,16 @@ def run_dpa_with_cp( ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" + q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -350,9 +384,25 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() # run attention with fp8_context: + # q, k, v, 
out in FP8; dout in F16 out_ = core_attn( q_, k_, @@ -363,27 +413,30 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) - if fp8_mha: - assert isinstance(out, Float8Tensor) - assert isinstance(out_, Float8Tensor) - out = out.dequantize() - out_ = out_.dequantize() - - # get outputs dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() - for x in [out_, dq_, dk_, dv_, d_softmax_offset_]: - if x is not None: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) + + # get outputs + tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] + if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[4] = tensors_to_deq + for tensor in tensors: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) + out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": @@ -394,17 +447,17 @@ def run_dpa_with_cp( x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [q.grad, k.grad, v.grad, out] + for x in [dq, dk, dv, out] ] dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] dq_, dk_, dv_, out_ = [ x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [q_.grad, k_.grad, v_.grad, out_] + for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in 
[k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a5c3457791..e3a4de73b0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1693,23 +1693,44 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_mha_fp8_vs_f16( + dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode +): """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + fp8_mha=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), + fp8=True, + fp8_meta=fp8_meta, 
is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: - pytest.skip("Less than two backends to compare.") + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: available_backends, _, fused_attn_backends = get_available_attention_backends( config, @@ -1727,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1735,19 +1756,20 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, RoPE, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-1 rmse_tol = 0.15 - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + 
logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1758,6 +1780,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rmse_tol, True, ) + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1784,7 +1808,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): +def _run_mha_fp8_vs_f16( + dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe +): """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -1794,15 +1820,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, - fp8_mha=fp8_mha, - ) - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: @@ -1911,7 +1928,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -1927,16 +1945,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): 
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -1956,32 +1991,44 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) + + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) 
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training + dtype, config, False, qkv_layout, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1992,12 +2039,40 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol, True, ) + if unfused_attn_supported: + logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + unfused_attn_fwd_fp8, + fused_attn_fwd_f16, + "unfused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + unfused_attn_bwd_fp8[i], + 
fused_attn_bwd_f16[i], + f"unfused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout assert torch.all( fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." else: + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, @@ -2021,9 +2096,10 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol, True, ) + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe): """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -2033,14 +2109,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_dpa, - ) - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( @@ -2147,6 +2215,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, + fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index c752d07d82..0f00b8b0ef 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -14,6 +14,10 @@ get_device_compute_capability, get_cudnn_version, 
) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils _current_file = pathlib.Path(__file__).resolve() @@ -27,6 +31,8 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) +test_essential = True + model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA @@ -63,12 +69,22 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +dtypes = ["bf16", "fp16"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] + model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} + dtypes = ["bf16"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): @@ -77,6 +93,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV 
P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -162,14 +179,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): } +dtypes = ["bf16", "fp16", "fp8"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} + dtypes = ["bf16", "fp8"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -@pytest.mark.parametrize("fp8_mha", [False, True]) -def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) +@pytest.mark.parametrize("fp8_bwd", [True, False]) +@pytest.mark.parametrize("fp8_mha", [True, False]) +@pytest.mark.parametrize("fp8_dpa", [True, False]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("f16_O", [True, False]) +def test_cp_with_fused_attention( + dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -180,10 +213,15 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha 
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") + if dtype == "fp8" and not fp8_dpa and fp8_mha: + pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") + if dtype != "fp8" and fp8_bwd: + pytest.skip("Only fp8 works with fp8_bwd=True!") config = model_configs_fused_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -211,8 +249,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) - if dtype != "fp8" and fp8_mha: - pytest.skip("Only fp8 works with fp8_mha=True!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("Only fp8 works with scaling_mode != None!") + if dtype == "fp8" and scaling_mode is None: + pytest.skip("fp8 only works with scaling_mode != None!") + if ( + dtype == "fp8" + and scaling_mode == "current" + and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + ): + pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + if f16_O and (dtype != "fp8" or scaling_mode != "current"): + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" 
and config.head_dim_qk != config.head_dim_v: @@ -229,10 +281,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + fp8_meta = {} + fp8_meta["recipe"] = None + fp8_meta["local_recipes"] = [] + fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha) + if fp8 and scaling_mode == "delayed": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)] + if fp8 and scaling_mode == "current": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [ + Float8CurrentScaling(fp8_dpa=True), + DelayedScaling(fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -246,7 +313,11 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, log_level=pytest_logging_level, ), check=True, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1d6435ad8a..ba0f845789 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -129,7 +129,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( window_size_right, true, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -585,7 +587,9 
@@ void fused_attn_arbitrary_seqlen_bwd_impl( window_size_right, deterministic, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 995dbda7fb..21c544491a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + 
NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); try { FADescriptor_v1 descriptor{b, @@ -1699,8 +1707,10 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, - fwd_tensor_type, - fwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1739,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -1787,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1( descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() @@ -1839,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1916,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1( {descale_v, devPtrDescaleV}, 
{descale_s, devPtrDescaleS}, {scale_s, devPtrScaleS}, - {scale_o, devPtrScaleO}, {attn_scale, &scaling_factor}, {O, devPtrO}, {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; + if (is_delayed_scaling) { + variant_pack[scale_o] = devPtrScaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; } */ @@ -1963,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1978,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); 
try { FADescriptor_v1 descriptor{b, @@ -2005,8 +2035,10 @@ void fused_attn_fp8_bwd_impl_v1( 0, 0, false, - fwd_tensor_type, - bwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + do_tensor_type, + dqkv_tensor_type}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2059,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -2099,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1( o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_stride(o_stride) + .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d}) @@ -2125,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1( descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = 
mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() @@ -2214,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - dO->set_data_type(bwd_tensor_type); - dQ->set_data_type(bwd_tensor_type); - dK->set_data_type(bwd_tensor_type); - dV->set_data_type(bwd_tensor_type); + dO->set_data_type(do_tensor_type); + dQ->set_data_type(dqkv_tensor_type); + dK->set_data_type(dqkv_tensor_type); + dV->set_data_type(dqkv_tensor_type); std::tuple, // q std::shared_ptr, // k @@ -2298,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_o, devPtrDescaleO}, {descale_dO, devPtrDescaledO}, {descale_s, devPtrDescaleS}, {descale_dP, devPtrDescaledP}, {scale_s, devPtrScaleS}, - {scale_dQ, devPtrScaledQ}, - {scale_dK, devPtrScaledK}, - {scale_dV, devPtrScaledV}, {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, @@ -2316,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; + if (is_delayed_scaling) { + variant_pack[scale_dQ] = devPtrScaledQ; + variant_pack[scale_dK] = devPtrScaledK; + variant_pack[scale_dV] = devPtrScaledV; + } + if (!is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2366,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = 
nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; @@ -2432,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, @@ -2467,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2484,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrDescaleV = input_QKV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2527,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, 
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, @@ -2565,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2633,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2671,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked( cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type 
= output_dQ->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; @@ -2688,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked( void* devPtrDescaleV = input_KV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2733,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, @@ -2822,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -2831,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - 
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2878,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleV = input_Q->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2911,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; @@ -2924,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, 
num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 0a0197423c..f03774f8ed 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -111,21 +111,24 @@ struct FADescriptor_v1 { std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; - cudnn_frontend::DataType_t fwd_tensor_type; - cudnn_frontend::DataType_t bwd_tensor_type; + cudnn_frontend::DataType_t qkv_tensor_type; + cudnn_frontend::DataType_t o_tensor_type; + cudnn_frontend::DataType_t do_tensor_type; + cudnn_frontend::DataType_t dqkv_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, - bwd_tensor_type) < + window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, + o_tensor_type, do_tensor_type, dqkv_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.fwd_tensor_type, rhs.bwd_tensor_type); + rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type); } }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ea0287ef15..179d618b35 100644 --- a/transformer_engine/common/recipe/__init__.py +++ 
b/transformer_engine/common/recipe/__init__.py @@ -209,6 +209,7 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " + f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) @@ -226,10 +227,11 @@ class Float8CurrentScaling(Recipe): pass. """ + use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID - fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) @@ -238,9 +240,6 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert ( - not self.fp8_dpa and not self.fp8_mha - ), "FP8 attention is not supported for Float8CurrentScaling." 
def __repr__(self) -> str: return ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4a60bd9fe1..f72c1eb9e0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -16,14 +16,16 @@ import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - SplitAlongDim, get_device_compute_capability, - combine_tensors, split_tensor_along_dim, ) -from transformer_engine.pytorch.utils import attention_mask_func +from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensor, + QuantizedTensorBase, prepare_for_saving, restore_from_saved, ) @@ -40,7 +42,7 @@ META_O, META_QKV, ) -from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype +from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -53,6 +55,9 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -130,6 +135,58 @@ fa_utils.set_flash_attention_3_params() +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag 
allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + + +class FP8EmulationFunc(torch.autograd.Function): + """ + Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: + - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through) + - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through) + """ + + @staticmethod + def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): + # pylint: disable=missing-function-docstring + if quantizer_name == "QKV_quantizer": + query_layer, key_layer, value_layer = [ + x.contiguous() for x in [tensor1, tensor2, tensor3] + ] + q_fp8, k_fp8, v_fp8 = combine_and_quantize( + qkv_layout, query_layer, key_layer, value_layer, quantizer + ) + tensors = combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + ) + elif quantizer_name in ["S_quantizer", "O_quantizer"]: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) + ctx.quantizer = quantizer + ctx.quantizer_name = quantizer_name + ctx.qkv_layout = qkv_layout + return tensors[0], tensors[1], tensors[2] + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + # pylint: disable=missing-function-docstring + if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + elif ctx.quantizer_name == "dQKV_quantizer": + query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] + dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + ) + tensors = combine_and_dequantize( + ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + ) + else: + tensors = grad1, grad2, grad3 + return 
tensors[0], tensors[1], tensors[2], None, None, None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -189,6 +246,10 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + fp8_output: bool = False, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -286,6 +347,35 @@ def forward( if apply_qk_layer_scaling: scale /= self.layer_number + if fp8: + # get quantizers from DPA; all Nones if not fp8 + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers) + ) + # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8_recipe.float8_current_scaling(): + S_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=S_quantizer.dtype, device="cuda" + ) + dP_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=dP_quantizer.dtype, device="cuda" + ) + + if "2" in qkv_layout or "3" in qkv_layout: + qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) + qkv_layout = "_".join([qkv_format] * 3) + # quantize and dequantize QKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + ) + # quantize and dequantize dQKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + ) + # Raw attention scores. 
[b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -330,6 +420,12 @@ def forward( dtype=query_layer.dtype ) + if fp8: + # quantize and dequantize dP to emulate FP8 + matmul_result, *_ = FP8EmulationFunc.apply( + matmul_result, None, None, dP_quantizer, "dP_quantizer", None + ) + # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( @@ -379,6 +475,12 @@ def forward( # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + if fp8: + # quantize and dequantize S to emulate FP8 + attention_probs, *_ = FP8EmulationFunc.apply( + attention_probs, None, None, S_quantizer, "S_quantizer", None + ) + # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -413,6 +515,20 @@ def forward( # [tq, np, hn] --> [tq, hp] context_layer = context_layer.view(total_tokens, -1) + if fp8: + # quantize and dequantize O to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, O_quantizer, "O_quantizer", None + ) + # quantize and dequantize dO to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, dO_quantizer, "dO_quantizer", None + ) + + # quantize O + if fp8_output: + context_layer = O_quantizer(context_layer) + return context_layer @@ -511,6 +627,7 @@ def forward( quantizers=None, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), + fp8_output: bool = False, ) -> torch.Tensor: """flash-attn fprop""" @@ -716,6 +833,7 @@ def forward( quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, + fp8_output=fp8_output, ) else: from transformer_engine.pytorch.cpu_offload import ( @@ -815,8 +933,6 @@ def convert_to_torch_float8(tensor, dtype): ) return out - # "fp8_mha" decides outputs in fp8, while inputs are inferred from 
- # the real dtype assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." @@ -863,7 +979,7 @@ def convert_to_torch_float8(tensor, dtype): if fp8: output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: O_quantizer = quantizers["scaling_fwd"][META_O] output = O_quantizer(output) @@ -891,7 +1007,7 @@ def convert_to_torch_float8(tensor, dtype): if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: output_data = ( output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) .transpose(0, 1) @@ -915,7 +1031,7 @@ def convert_to_torch_float8(tensor, dtype): class FusedAttnFunc(torch.autograd.Function): - """Function for FusedAttention with separate Q, K, V tensors""" + """FusedAttention forward and backward implementation""" @staticmethod def forward( @@ -949,55 +1065,71 @@ def forward( quantizers, deterministic, softmax_offset, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn - fake_dtype = q.dtype + # add NVTX range + nvtx_label = "transformer_engine.FusedAttnFunc.forward" + nvtx_range_push(f"{nvtx_label}") + + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + + # input types are inferred from the 
real data while output types are controlled by fp8_output + # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + + # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel in FP8: + is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + # get nominal data type for out + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + out_nominal_dtype = q.dtype + if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." 
- is_input_fp8 = isinstance(q, Float8Tensor) - q_fp8, k_fp8, v_fp8 = None, None, None + # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - match qkv_group: - case 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = QKV_quantizer(qkv) - q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) - case 2: - q_fp8 = QKV_quantizer(q) - dim = qkv_layout.split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = QKV_quantizer(kv_c) - k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) - case 3: - q_fp8 = QKV_quantizer(q) - k_fp8 = QKV_quantizer(k) - v_fp8 = QKV_quantizer(v) - case _: - raise "Invalid qkv_layout " + qkv_layout - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - out_fp8, aux_ctx_tensors = fused_attn_fwd( + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1006,7 +1138,7 @@ def forward( q_fp8, k_fp8, v_fp8, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1026,42 +1158,54 @@ def forward( rng_gen, 
softmax_offset, ) - if is_output_fp8: - out_ret = out_fp8 + + # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_fp8 = out_ + out = out_ + + if isinstance(out_, Float8Tensor): + if not is_output_fp8 or not is_bwd_fp8: + out = out_.dequantize().view(out_.shape) else: - out_ret = out_fp8.dequantize().view(out_fp8.shape) - # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 - # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn - out_save = out_ret + if is_output_fp8 or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_fp8 = O_quantizer(out_) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate + # return appropriate tensors + out_ret = out_fp8 if is_output_fp8 else out + + # save appropriate tensors + fp8_tensors = (None, None, None, None) + qkvo_tensors = (None, None, None, None) + if is_bwd_fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + qkvo_tensors = (None, None, None, out) + else: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + else: if is_input_fp8: - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) - q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) - if qkv_group == 2: - q = q.dequantize() - dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, 
kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = kv.dequantize() - k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) - if qkv_group == 3: - q = q.dequantize() - k = k.dequantize() - v = v.dequantize() - if is_output_fp8: - out_save = out_fp8.dequantize() - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + qkvo_tensors = (q, k, v, out) else: - # q, k, v, out_ret: torch.float16 or torch.bfloat16 - out_ret, aux_ctx_tensors = fused_attn_fwd( + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1070,7 +1214,7 @@ def forward( q, k, v, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1090,10 +1234,18 @@ def forward( rng_gen, softmax_offset, ) - out_save = out_ret + out = out_ + out_ret = out_ fp8_tensors = (None, None, None, None) + qkvo_tensors = (q, k, v, out) + + nvtx_range_pop(f"{nvtx_label}") - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = is_bwd_fp8 + # assume fwd and bwd always use the same high precision, i.e. 
torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + ctx.nominal_dtype = out_nominal_dtype from transformer_engine.pytorch.cpu_offload import ( CPUOffloadEnabled, @@ -1104,7 +1256,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out_save] + tensor_list = [q, k, v, out] qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) @@ -1112,7 +1264,6 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *qkvo_tensors, @@ -1126,11 +1277,14 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + ctx.layer_number = layer_number + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.S_quantizer = S_quantizer - if ctx.fp8: + if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer): ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() @@ -1155,17 +1309,15 @@ def forward( @staticmethod def backward(ctx, d_out): # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
- - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 - fake_dtype = d_out.dtype - d_out = d_out.contiguous() + # d_out is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase): + d_out = ctx.dO_quantizer(d_out) + if not ctx.use_FAv2_bwd: + d_out._data = d_out._data.contiguous() + elif not ctx.use_FAv2_bwd: + d_out = d_out.contiguous() ( q_fp8, k_fp8, @@ -1219,16 +1371,55 @@ def backward(ctx, d_out): dk = dk[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn"): + with torch.cuda.nvtx.range("FusedAttnFunc.backward"): + # get nominal data type of dq, dk, dv + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.fp8: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 if ctx.is_output_fp8: d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - dqkv_dtype = TE_DType[d_out_fp8._data.dtype] - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 - dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # get tex.DType for dq, dk, dv data + dqkv_te_dtype = d_out_fp8._fp8_dtype + + # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # fp8_dtype = tex.DType.kFloat8E4M3 + # d_out_fp8: 
Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # + # dq_, dk_, dv_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -1236,10 +1427,10 @@ def backward(ctx, d_out): q_fp8, k_fp8, v_fp8, - out_fp8, + out_, d_out_fp8, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1258,40 +1449,40 @@ def backward(ctx, d_out): ctx.deterministic, ) - # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 - # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 - if not ctx.is_input_fp8: - qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = ctx.qkv_layout.find("3") - dqkv_fp8_data = combine_tensors( - [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim - ) - dqkv_fp8 = dq_fp8.make_like( - tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape - ) - dqkv = dqkv_fp8.dequantize() - dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) - if qkv_group == 2: - dq = dq_fp8.dequantize() - dim = ctx.qkv_layout.split("_")[1].find("2") - dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] - ) - dkv = dkv_c_fp8.dequantize() - dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) - if qkv_group == 3: - dq = dq_fp8.dequantize() - dk = dk_fp8.dequantize() - dv = dv_fp8.dequantize() 
- else: - dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 + # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + dq, dk, dv = dq_, dk_, dv_ + is_float8tensor = isinstance(dq_, Float8Tensor) + if is_float8tensor and not ctx.is_input_fp8: + # return in F16 + dq, dk, dv = combine_and_dequantize( + ctx.qkv_layout, + dq_, + dk_, + dv_, + src_nominal_dtype=dq_.dtype, + ) + if not is_float8tensor and ctx.is_input_fp8: + # return in FP8 + dq, dk, dv = combine_and_quantize( + ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + ) + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) else: - if isinstance(d_out, QuantizedTensor): - d_out = d_out.dequantize() - dqkv_dtype = TE_DType[d_out.dtype] - # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 + if isinstance(d_out, QuantizedTensorBase): + d_out = d_out.dequantize(dtype=ctx.nominal_dtype) + dqkv_te_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1302,8 +1493,8 @@ def backward(ctx, d_out): v, out, d_out, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1358,6 +1549,8 @@ def backward(ctx, d_out): None, None, d_softmax_offset, + None, + None, ) @@ -1463,6 +1656,7 @@ def forward( pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1563,15 +1757,27 @@ def forward( ) if fp8: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] assert 
fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" " is required for FP8 attention!" ) assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" - assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across TP+CP group is necessary when using context parallelism with" - " FP8!" - ) + if fp8_recipe.delayed(): + assert not context_parallel or fp8_recipe.reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism" + " with FP8!" + ) + if fp8_recipe.float8_current_scaling() and context_parallel: + all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + for q in all_quantizers: + if isinstance(q, Float8CurrentScalingQuantizer): + q.with_amax_reduction = True + q.amax_reduction_group = ( + cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group + ) if context_parallel: assert ( @@ -1615,6 +1821,8 @@ def forward( pad_between_seqs=pad_between_seqs, softmax_type=self.softmax_type, softmax_offset=softmax_offset, + fp8_output=fp8_output, + layer_number=self.layer_number, ) else: with self.attention_dropout_ctx(): @@ -1648,6 +1856,8 @@ def forward( quantizers, self.deterministic, softmax_offset, + fp8_output, + self.layer_number, ) # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 2e4b6b6177..539caffbb9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -9,7 +9,6 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - combine_tensors, get_cudnn_version, nvtx_range_pop, nvtx_range_push, @@ -20,7 +19,9 @@ fused_attn_bwd, FusedAttnBackend, ) +from transformer_engine.pytorch.fp8 
import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -41,6 +42,9 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) _cu_seqlens_info_with_cp_cache = {} @@ -48,6 +52,9 @@ _seq_chunk_ids_cache_for_reordering_after_attn = {} _softmax_offset_chunk_ids_cache = {} +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm @@ -226,11 +233,11 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): @jit_fuser def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -240,13 +247,13 @@ def reorder_seq_chunks_for_a2a_before_attn(x, 
chunk_ids_for_a2a, seq_dim, cp_siz @jit_fuser def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -278,16 +285,16 @@ def flash_attn_a2a_communicate( x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # or [s, b, h, d] -> [s, b, cp, h//cp, d] x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] a2a_inputs[i] = x.movedim(-3, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -298,8 +305,8 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = 
a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -309,11 +316,11 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -467,6 +474,585 @@ def get_fa_args( ] +def cp_p2p_fwd_prepare_qkv( + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + step, + cp_size, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P forward""" + cu_seqlens_q_per_step = None + cu_seqlens_kv_per_step = None + if section in ["diagonal", "all"]: + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + rank_ = rank if section == "diagonal" else (rank - step) % cp_size + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, 
cp_size, rank_, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, k_part, v_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part]] + + elif section == "lower-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv_half + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part = q_part.view(q_part.shape[0], -1, *q_part.shape[-2:]) + # [b, 2, sk//2, h, d] -> [b, sk//2, h, d] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] 
+ elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part = q_part.view(-1, *q_part.shape[-3:]) + # [2, sk//2, b, h, d] -> [sk//2, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + + elif section == "upper-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q_half + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part = q_part[:, 1, ...] 
+ # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part = q_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part = tex.thd_read_half_tensor(q_part, cu_seqlens_q_padded, 1) + + return q_part, k_part, v_part, cu_seqlens_q_per_step, cu_seqlens_kv_per_step + + +def cp_p2p_fwd_fused_attn( + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step, + O_quantizer_per_step, + rank, + step, + cp_size, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FusedAttention backend""" + attn_bias_inputs = None + max_seqlen_q_ = None + max_seqlen_kv_ = None + cu_seqlens_q_ = None + cu_seqlens_kv_ = None + attn_mask_type_ = None + cu_seqlens_q_padded_ = None + cu_seqlens_kv_padded_ = None + if section in ["diagonal", "all"]: + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + elif section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = attn_bias[..., 
idx, :].contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = ( + cu_seqlens_kv_padded // 2 if cu_seqlens_kv_padded is not None else None + ) + elif section == "upper-triangle": + q_part = q_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q // 2 + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded // 2 if cu_seqlens_q_padded is not None else None + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step + fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step + + out_per_step, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_, + cu_seqlens_kv_, + q_part, + k_part, + v_part, + fake_dtype=fwd_nominal_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + **fp8_meta_kwargs, + ) + + if fp8: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, rng_states, 
*rest = aux_ctx_tensors + attn_bias = rest[0] if len(rest) > 0 else None + + return out_per_step, softmax_lse_per_step, rng_states, attn_bias + + +def cp_p2p_fwd_flash_attn( + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FlashAttention backend""" + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + causal_ = False + if section in ["diagonal", "all"]: + causal_ = section == "diagonal" + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + if section in ["lower-triangle", "upper-triangle"]: + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_, + cu_seqlens_kv=cu_seqlens_kv_, + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + ) + fa_outputs = flash_attn_fwd( + q_part, + k_part, + v_part, + *fa_forward_args_thd, + causal=causal_, + **fa_forward_kwargs, + ) + rng_states = None + if not fa_utils.v2_7_0_plus: + out_per_step = fa_outputs[4] + softmax_lse_per_step = fa_outputs[5] + if not use_flash_attn_3: + rng_states = fa_outputs[7] + else: + out_per_step = fa_outputs[0] + softmax_lse_per_step = fa_outputs[1] + if not use_flash_attn_3: + rng_states = fa_outputs[3] + + return out_per_step, softmax_lse_per_step, rng_states + + +def cp_p2p_bwd_prepare_qkv( + q_part, + k_part, + v_part, + out_part, + dout_part, + qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + section, +): + 
"""Prepare q, k, v and cu_seqlens for CP P2P backward""" + if section in ["diagonal", "all"]: + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) + for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif section == "lower-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, out_part, dout_part] + ] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, out_part, dout_part] + ] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + elif section == "upper-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part, out_part, dout_part = q_part[:, 1], out_part[:, 1], dout_part[:, 1] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part, out_part, dout_part = q_part[1], out_part[1], dout_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part, out_part, dout_part = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x 
in [q_part, out_part, dout_part] + ] + + return q_part, k_part, v_part, out_part, dout_part + + +def cp_p2p_bwd_fused_attn( + fp8, + fp8_recipe, + q_fp8, + kv_fp8, + out_fp8, + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + max_seqlen_q, + max_seqlen_kv, + step, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + deterministic, + fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + S_quantizer, + dP_quantizer_per_step, + dQKV_quantizer_per_step, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FusedAttention backend""" + if fp8: + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + attn_mask_type_ = attn_mask_type + + if section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_kv_padded_ = None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + elif section == "upper-triangle": + q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] + if fp8: + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q // 2 + cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + + if attn_dbias is not None: + 
aux_tensors += [attn_biases[cp_size - step - 1]] + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step + + dq, dk, dv, dbias, *_ = fused_attn_bwd( + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv_per_step[cp_size - step - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + bwd_nominal_dtype, + bwd_output_te_dtype, + aux_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + deterministic=deterministic, + **fp8_meta_kwargs, + ) + + return dq, dk, dv, dbias + + +def cp_p2p_bwd_flash_attn( + use_flash_attn_3, + qkv_format, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + step, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FlashAttention backend""" + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 
+ if not use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - step - 1] + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + softmax_lse__ = softmax_lse + causal_ = False + if section == "diagonal": + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, 0) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + causal_ = True + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + softmax_lse__ = softmax_lse_ + + fa_backward_args_thd = get_fa_args( + False, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + dq=dq, + dk=dk, + dv=dv, + ) + flash_attn_bwd( + dout_part, + q_part, + k_part, + v_part, + out_part, + softmax_lse__, + *fa_backward_args_thd, + causal=causal_, + **fa_backward_kwargs, + ) + + return dq, dk, dv + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. 
Exchange KV between CP ranks @@ -508,30 +1094,24 @@ def forward( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") - enable_mla = k.shape[-1] != v.shape[-1] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.forward" + nvtx_range_push(f"{nvtx_label}") + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 if isinstance(cp_group, list): - assert ( - qkv_format != "thd" - ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" - assert attn_bias_type == "no_bias", ( - f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" - " yet!" - ) cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) cp_group = cp_group[1] - else: - cp_group_a2a = None - cp_size_a2a = 1 - rank_a2a = 0 - cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] @@ -541,18 +1121,19 @@ def forward( device_compute_capability < (10, 0) and cp_size == 2 ) + # set up attention args + enable_mla = k.shape[-1] != v.shape[-1] causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - if enable_mla: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - else: - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = 
None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -563,7 +1144,6 @@ def forward( q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv ) else: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size @@ -573,79 +1153,110 @@ def forward( cu_seqlens_kv_per_step = [None for _ in range(cp_size)] fused_attn_backend = None - qkv_dtype = q.dtype amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] - O_CP_quantizer_per_step = [None for _ in range(cp_size)] - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False + O_quantizer_per_step = [None for _ in range(cp_size)] + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + fwd_nominal_dtype = q.dtype + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] ( QKV_quantizer, O_quantizer, - O_CP_quantizer, S_quantizer, dQKV_quantizer, - dQKV_CP_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + q_f16 = None + q_fp8, k_fp8, v_fp8 = (None, None, None) + # communicate for the 'a2a' part of 'a2a+p2p' + if cp_size_a2a > 1: + if fp8 and is_input_fp8: + QKV_quantizer = q._quantizer + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = (q._data, k._data, v._data) + chunk_ids_for_a2a = 
get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if fp8 and is_input_fp8: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) + ] + q, k, v = q_fp8, k_fp8, v_fp8 + + # convert qkv to the right type if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if is_input_fp8: + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] else: - assert False, "FP8 is only supported with Fused Attention!" 
+ # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_f16 = q + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # amax_per_step[0]: amax_s x cp_size + # amax_per_step[1]: amax_o x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i] = O_quantizer.copy() + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) - - q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True - ) - if not fp8: - q_f16 = q - elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16 = q - q = QKV_quantizer(q_f16)._data - + # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" 
if causal: if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] + # [s, b, h, d] -> [2, s//2, b, h, d] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -654,7 +1265,7 @@ def forward( assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 ), "Sequence length does not meet divisible requirements!" - # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], 2, @@ -662,12 +1273,14 @@ def forward( 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size), ) - # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + # stats tensor shape: + # BHS1 before cuDNN 9.6 or flash-attention v2.6/v3 + # TH1 after cuDNN 9.6 or flash-attention v2.6/v3 softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: @@ -675,7 +1288,9 @@ def forward( else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + # set up args for FlashAttention backend flash_attn_fwd = None + fa_forward_kwargs = {} if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_flash_attn_3: @@ -714,11 +1329,9 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - # Flash Attn inputs + # set up inputs for forward q_inputs = [None, None] kv_inputs = [None, None] - attn_bias_inputs = [None, 
None] - # Flash Attn outputs out_per_step = [None for _ in range(cp_size)] softmax_lse_per_step = [None for _ in range(cp_size)] rng_states = [None for _ in range(cp_size)] @@ -730,19 +1343,15 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if enable_mla: - # If MLA, the shape of k and v does not match, so we flatten them - # and split them after receiving them. - k_shape = k.shape - k_numel = k.numel() - v_shape = v.shape - p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) - elif qkv_format in ["bshd", "sbhd"]: - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: # qkv_format == "thd" - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] + # P2P communication and compute: each rank has cp_size steps + # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None for i in range(cp_size + 1): if i < cp_size: @@ -763,634 +1372,205 @@ def forward( batch_p2p_comm, ) - if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - kv_inputs[i % 2] = p2p_comm_buffers[i] + kv_inputs[i % 2] = p2p_comm_buffers[i] + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + q_part = q + + prepare_inputs = [ + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + i, + cp_size, + ] + if use_fused_attention: + fused_attn_inputs = [ + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + 
attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step[i], + O_quantizer_per_step[i], + rank, + i, + cp_size, + ] else: - # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data - if enable_mla: - # If MLA, k and v are flattened, so split them after receiving. - k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) - v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + flash_attn_inputs = [ + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + ] + + # cp_size = 4: + # + # step + # section | 0 1 2 3 + # -------------------- + # G 0 | d, u, u, u, + # P 1 | l, d, u, u, + # U 2 | l, l, d, u, + # 3 | l, l, l, d, + # + # Each GPU holds a slice of Q and KV. To compute the attention of each Q slice, each GPU + # runs cp_size steps to get the partial results of its own Q and all KV slices. KV is communicated + # in a point-to-point, ring fashion. For attn_mask_type = causal, there are three attention + # patterns in the cp_size x cp_size (i.e. GPU x step) matrix, the diagonal tiles, the lower-triangle + # tiles, and the upper-triangle tiles. For attn_mask_type != causal, the pattern is all the same. 
if causal: if i == 0: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q + section = "diagonal" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - # If MHA, then split the KV into 
k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - fake_dtype=qkv_dtype, - fused_attention_backend=fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=True, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - False, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] - k_part = k_part[:, 0, ...] - v_part = v_part[:, 0, ...] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] 
- elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] - elif qkv_format == "thd": - q_inputs[i % 2] = q - if enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor( - k_part, cu_seqlens_kv_padded, 0 - ) - v_part = tex.thd_read_half_tensor( - v_part, cu_seqlens_kv_padded, 0 - ) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + section = "lower-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], 
aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv // 2, ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, - ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q_half - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...] 
- if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1] - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + section = "upper-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - q_inputs[i % 2] = q_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, 
internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q // 2, - max_seqlen_kv=max_seqlen_kv, ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, - ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv + # all tiles + section = "all" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + 
cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, - ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. 
- # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q, - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] + # softmax_lse correction if i > 0: - # wait until fwd restuls correction of last step is done + # wait until fwd results correction of last step is done if i > 1: flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] or - # [t, np, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] or + # [t, h, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() ) if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) + # dequantize out_per_step to torch.float32 + if fp8_recipe.delayed(): + out_per_step[i - 1] = out_per_step[i - 1].dequantize( + dtype=torch.float32 + ) + if fp8_recipe.float8_current_scaling(): + out_per_step[i - 1] = 
out_per_step[i - 1].to(dtype=torch.float32) + if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": @@ -1430,6 +1610,7 @@ def forward( if causal and rank < (cp_size - 1): second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: @@ -1482,7 +1663,6 @@ def forward( softmax_lse_in_packed_format, ) - kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] @@ -1497,39 +1677,84 @@ def forward( ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] + # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] + # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + # update FP8 quantizers: amax across cp_size steps if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) - O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) + O_quantizer.amax.copy_(amax_cp_fwd[1]) - out_fp8 = None - out_f16 = out.to(qkv_dtype) - - if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = O_quantizer(out_f16) # final result + if fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + # prepare for return and ctx saves + out_fp8 = None + out_f16 = out.to(fwd_nominal_dtype) + if fp8 and ( + is_output_fp8 + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + ): + out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 - if fp8 and 
int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8._data + ctx.layer_number = layer_number + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = fp8 and is_bwd_fp8 + + kv_fp8 = None + kv = p2p_comm_buffers[-1] + if fp8: + q_fp8, kv_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8], [q, kv]) + ] + # q, kv, out + fp8_tensors = (None, None, None) + f16_tensors = (None, None, None) + if ctx.fp8: + # fwd: fp8, bwd: fp8, save all fp8 + fp8_tensors = (q_fp8, kv_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + f16_tensors = (None, None, out_f16) elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, kv, out_f16 + # fwd: fp8, bwd: f16, save all f16 + # dequantize fp8 inputs + q_f16 = q_fp8.dequantize() + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8: + # fwd: fp8, bwd: f16, save all f16 + # inputs are already in f16 + q_f16 = q_f16.view(q.shape) + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) else: + # fwd: f16, bwd: f16, save all f16 + # inputs and kernels are both f16 q_f16 = q_f16.view(q.shape) - q_save, kv_save, out_save = q_f16, kv, out_f16 + kv_f16 = kv + f16_tensors = (q_f16, kv_f16, out_f16) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - kv_save, - out_save, + *fp8_tensors, + *f16_tensors, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, @@ -1559,21 +1784,18 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 ctx.enable_mla = enable_mla - if enable_mla: - ctx.k_numel = k_numel - ctx.k_shape = k_shape - ctx.v_shape = v_shape + ctx.k_numel = k_numel + ctx.k_shape 
= k_shape + ctx.v_shape = v_shape - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer - ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.QKV_quantizer = QKV_quantizer @@ -1586,17 +1808,31 @@ def forward( ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + + nvtx_range_pop(f"{nvtx_label}") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.backward" + nvtx_range_push(f"{nvtx_label}") + + # dout is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase): + dout = ctx.dO_quantizer(dout) + if ctx.use_fused_attention: + dout._data = dout._data.contiguous() + elif ctx.use_fused_attention: + dout = dout.contiguous() + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a - cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] @@ -1606,33 +1842,38 @@ def backward(ctx, dout): device_compute_capability < (10, 0) and cp_size == 2 ) - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - ) + # get saved tensors + ( + q_fp8, + kv_fp8, + out_fp8, + q, + kv, + out, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) 
cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] rng_states = other_tensors[cp_size * 2 : cp_size * 3] attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + # set up attention args causal = "causal" in ctx.attn_mask_type - padding = "padding" in ctx.attn_mask_type - seq_dim = None + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") - if ctx.enable_mla: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + # set up attention bias if attn_biases[0] is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] ) @@ -1640,6 +1881,7 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None + # set up softmax_lse softmax_lse_ = None if causal and ctx.second_half_lse_seqlen is not None: if ctx.qkv_format == "thd": @@ -1650,86 +1892,124 @@ def backward(ctx, dout): ctx.second_half_lse_seqlen, ) else: - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, h, sq] -> [b, h, 2, sq//2] softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() - # [b, np, sq//2] -> [b, np, sq//2, 1] or - # [t//2, np] -> [t//2, np, 1] + # [b, h, sq//2] -> [b, h, sq//2, 1] or + # [t//2, np] -> 
[t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() - # [b, np, sq] -> [b, np, sq, 1] or - # [t, np] -> [t, np, 1] + # [b, h, sq] -> [b, h, sq, 1] or + # [t, np] -> [t, h, 1] softmax_lse.unsqueeze_(-1) - dout = dout.contiguous() - dq = None - dout_dtype = dout.dtype + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + bwd_nominal_dtype = ctx.fwd_nominal_dtype + + # convert out, dout to the right type fused_attn_backend = None - fused_attn_dqkv_dtype = None amax_per_step = None dP_quantizer_per_step = [None for _ in range(cp_size)] - dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_quantizer_per_step = [None for _ in range(cp_size)] + buffer_dtype = torch.uint8 + dq_buffer = None + dout_fp8 = None + bwd_output_te_dtype = None + dkv_buffer = None if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- ctx.dO_quantizer = dout._quantizer - else: - dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) - dkv_fp8 = torch.empty( - (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device - ) - dkv_fp8_ = torch.empty_like(dkv_fp8) - p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] - dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() - dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype + # dout: torch.Tensor, dtype=torch.uint8 + if ctx.is_output_fp8: + dout_fp8 = dout else: - assert False, "FP8 is only supported with Fused Attention!" 
+ dout_fp8 = ctx.dO_quantizer(dout) + dout = dout_fp8._data + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # dout_fp8._fp8_dtype + bwd_output_te_dtype = ctx.dO_quantizer.dtype + + # create buffers for reduction in float32 + if ctx.fp8_recipe.delayed(): + dq_buffer = torch.empty( + (cp_size, *q.shape), + dtype=buffer_dtype, + device=q.device, + ) + if ctx.fp8_recipe.float8_current_scaling(): + dq_buffer = torch.empty( + q.shape, + dtype=torch.float32, + device=q.device, + ) + kv_recv_buffer = torch.empty_like(kv) + dkv_send_buffer = torch.empty( + (cp_size, *kv.shape), + dtype=buffer_dtype, + device=kv.device, + ) + dkv_recv_buffer = torch.empty_like(dkv_send_buffer) + p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] + if ctx.fp8_recipe.float8_current_scaling(): + dkv_buffer = torch.zeros( + kv.shape, + dtype=torch.float32, + device=kv.device, + ) + + # amax_per_step[0]: amax_dp x cp_size + # amax_per_step[1]: amax_dqkv x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if ctx.fp8_meta is not None: - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q = q.dequantize(dtype=ctx.qkv_dtype) - 
kv = kv.dequantize(dtype=ctx.qkv_dtype) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - if cp_size_a2a == 1: - dout = dout.dequantize(dtype=dout_dtype) - else: - ctx.dO_quantizer = dout._quantizer - dout = dout._data - dq = torch.empty_like(q) + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) + dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) @@ -1746,11 +2026,6 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - dout = dout.dequantize(dtype=dout_dtype) if ctx.enable_mla: out = out.view(*ctx.v_shape) @@ -1759,7 +2034,6 @@ def backward(ctx, dout): # MHA or GQA out = out.view(*q.shape) dout = dout.view(*q.shape) - send_recv_reqs = [] flash_attn_bwd = None if not ctx.use_fused_attention: @@ -1794,6 +2068,7 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + send_recv_reqs = [] for i in range(cp_size): # wait until KV is received for req in send_recv_reqs: @@ -1814,8 +2089,8 @@ def backward(ctx, dout): ) else: dkv_a2a_req = torch.distributed.all_to_all_single( - dkv_fp8, - dkv_fp8_, + dkv_send_buffer, + dkv_recv_buffer, group=ctx.cp_group, async_op=True, ) @@ -1832,593 +2107,146 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - 
q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.enable_mla: - k_part = kv[: ctx.k_numel].view(*ctx.k_shape) - v_part = kv[ctx.k_numel :].view(*ctx.v_shape) - # In reversed order of fwd - if causal: - if i == (cp_size - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout - if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + q_part, out_part, dout_part = q, out, dout - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, 
internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, - ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - 
max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, 0) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = 0 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=True, - **fa_backward_kwargs, - ) - elif i >= (cp_size - rank - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part[:, 0] - v_part = v_part[:, 0] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - if ctx.enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) - v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - if ctx.use_fused_attention: - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_ = kv_.contiguous() - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 
1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ + prepare_inputs = [ + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ] + if ctx.use_fused_attention: + fused_attn_inputs = [ + ctx.fp8, + ctx.fp8_recipe, + q_fp8, + kv_fp8, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ), + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + i, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + ctx.softmax_scale, + ctx.dropout_p, + qkv_layout, + ctx.attn_mask_type, + ctx.attn_bias_type, + ctx.deterministic, + ctx.fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + ctx.S_quantizer, + dP_quantizer_per_step[i], + dQKV_quantizer_per_step[i], + ] + else: + flash_attn_inputs = [ + ctx.use_flash_attn_3, + ctx.qkv_format, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + i, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + ] - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - 
dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, + # there are still three sections in these tiles based on their attention pattern + # for attn_mask_type = causal, and one for attn_mask_type != causal. 
+ if causal: + if i == (cp_size - 1): + section = "diagonal" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv // 2, - dq=dq_, - dk=dk_, - dv=dv_, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + elif i >= (cp_size - rank - 1): + section = "lower-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = 
cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section + ) + else: + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_, out_, dout_ = q[1], out[1], dout[1] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_, out_, dout_ = [ - tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) - for x in [q, out, dout] - ] - kv_ = kv + section = "upper-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - 
k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, 
- ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q // 2, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse_, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: + section = "all" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q - if not ctx.enable_mla: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - out_part = out - dout_part = dout - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - 
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - dkv_ = torch.empty_like(kv) - dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] - dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = 
rng_states[cp_size - i - 1] - flash_attn_bwd( - dout, - q, - k_part, - v_part, - out, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq = dq_fp8[(rank + i + 1) % cp_size] + # dq, dk, dv are reduced across steps in higher precision + # DelayedScaling: collect all results in uint8 to one tensor, dequantize to float32, then reduce + # CurrentScaling: dequantize partial results from each step to float32, then reduce + if ctx.fp8 and ctx.use_fused_attention: + if ctx.fp8_recipe.delayed(): + dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] + if ctx.fp8_recipe.float8_current_scaling(): + dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] + + # copy dq_ into the right buffer position + # buffer is cp_size x dq_size for DelayedScaling and the same size as dq for CurrentScaling + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dq = dq_buffer[(rank + i + 1) % cp_size] + else: + dq = dq_buffer if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): - # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or - # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + # [b, sq, h, d] -> [b, 2, sq//2, h, d] or + # [sq, b, h, d] -> [2, sq//2, b, h, d] dq_ = dq_.view(*dq.shape) - - if ctx.fp8: + if ctx.fp8 and ctx.fp8_recipe.delayed(): if i >= (cp_size - rank - 1) or not causal: dq.copy_(dq_) else: @@ -2428,6 +2256,8 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].fill_(0) dq[1].copy_(dq_) + else: + dq.copy_(dq_) elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) @@ -2463,18 +2293,19 @@ def backward(ctx, dout): else: dq.add_(dq_) + # dbias correction if attn_dbias is not None: idx = (rank + i + 1) % cp_size if i == (cp_size - 1) or not causal: - # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + # [b, h, sq, sk//cp] -> [b, h, sq, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) 
attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) elif i >= (cp_size - rank - 1): - # [b, np, sq, sk//(2*cp)] + # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) else: - # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) @@ -2483,254 +2314,159 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - if ctx.fp8: - if i < cp_size - 1: - dkv = dkv_fp8_[(rank + i + 1) % cp_size] - else: - dkv = dkv_fp8[(rank + i + 1) % cp_size] + # dkv correction + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] + elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] - if ctx.use_fused_attention: - if ctx.enable_mla: - dkv_ = None - elif ctx.qkv_format in ["bshd", "sbhd"]: - dkv_ = combine_tensors([dk_, dv_], -2) - elif ctx.qkv_format == "thd": - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment - if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. 
- dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - dkv_ = dkv_.movedim(-3, 0) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv_ = dkv_.view(*dkv.shape) - - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] or - # [2, sk//2, b, np, hn] - dk = dkv[: ctx.k_numel].view(*ctx.k_shape) - dv = dkv[ctx.k_numel :].view(*ctx.v_shape) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - dk_ = dk_.view(*ctx.k_shape) - dv_ = dv_.view(*ctx.v_shape) - - if ctx.fp8: - # enable_mla and fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dk[:, 1, ...].fill_(0) - dv[:, 0, ...].copy_(dv_) - dv[:, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dk[0].copy_(dk_) - dk[1].fill_(0) - dv[0].copy_(dv_) - dv[1].fill_(0) - else: - dk.copy_(dk_) - dv.copy_(dv_) - elif causal: - # enable_mla and not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_[:, 0, ...]) - dk[:, 1, ...].copy_(dk_[:, 1, ...]) - dv[:, 0, ...].add_(dv_[:, 0, ...]) - dv[:, 1, ...].copy_(dv_[:, 1, ...]) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_[0, ...]) - dk[1, ...].copy_(dk_[1, ...]) - dv[0, ...].add_(dv_[0, ...]) - dv[1, ...].copy_(dv_[1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "copy" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dk.add_(dk_) - dv.add_(dv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dv[:, 0, ...].copy_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].copy_(dk_) - dv[0, ...].copy_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "copy", "none" - ) - 
tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_) - dv[:, 0, ...].add_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_) - dv[0, ...].add_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dk.add_(dk_) - dv.add_(dv_) - else: # i == 0 + + # [b, 2, sk//2, h, d] or + # [2, sk//2, b, h, d] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8 and ctx.fp8_recipe.delayed(): + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: dk.copy_(dk_) dv.copy_(dv_) else: - # enable_mla and not fp8 and not causal - if i == 0: - dk.copy_(dk_) - dv.copy_(dv_) - else: # i > 0 + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "copy") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "copy") + else: dk.add_(dk_) dv.add_(dv_) - else: - if ctx.fp8: - # fp8 - if causal and i >= (cp_size - rank 
- 1) and i != (cp_size - 1): + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "copy", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "copy", "none") else: - dkv.copy_(dkv_) - elif causal: - # not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dkv.add_(dkv_) - else: # i == 0 - dkv.copy_(dkv_) - else: - # not fp8 and not causal - if i == 0: - dkv.copy_(dkv_) - else: # i > 0 - dkv.add_(dkv_) + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, 
cu_seqlens_kv_padded, "add", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "none") + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dq_fp8, fake_dtype=torch.float32, internal=True - ) - - if ctx.enable_mla: - # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] - dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) - dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) - dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dk_fp8, fake_dtype=torch.float32, internal=True - ) - dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] - dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] - else: - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + + dq = dq_buffer + if ctx.fp8_recipe.delayed(): + # [cp, b, 2, sk//2, h, d] or [cp, 2, sk//2, b, h, d] + dk = dkv_recv_buffer[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv = dkv_recv_buffer[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dq, dk, dv = [ + ctx.dQKV_quantizer.create_tensor_from_data( + x, fake_dtype=bwd_nominal_dtype, 
internal=ctx.dQKV_quantizer.internal + ) + for x in [dq, dk, dv] + ] + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + des_nominal_dtype=torch.float32, ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if causal: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) - dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *dq.shape[-3:]) - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - else: - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.fp8_recipe.float8_current_scaling(): + dk = dkv[: ctx.k_numel].view(ctx.k_shape) + dv = dkv[ctx.k_numel :].view(ctx.v_shape) + + if causal and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + dim = ctx.qkv_format.index("s") + dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]] if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - if ctx.enable_mla: - dk[cu_seqlens_kv_padded[-1] :].fill_(0) - dv[cu_seqlens_kv_padded[-1] :].fill_(0) - else: - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - assert torch.uint8 not in [dq.dtype, 
dkv.dtype] - if ctx.enable_mla: - dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] - else: - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - if not ctx.enable_mla: - dk, dv = dkv[0], dkv[1] + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + + if ctx.fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) if cp_size_a2a > 1: + if ctx.fp8 and ctx.is_input_fp8: + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2741,20 +2477,21 @@ def backward(ctx, dout): ctx.cp_stream, False, ) + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] if ctx.qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if attn_dbias is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) - # converting torch.uint8 to float8tensor - if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + nvtx_range_pop(f"{nvtx_label}") return ( None, @@ -2783,6 +2520,8 @@ def backward(ctx, dout): None, None, None, + None, 
+ None, ) @@ -2912,22 +2651,22 @@ def forward( else: cu_seqlens_q_padded = None - # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) @@ -2947,8 +2686,8 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( @@ -2970,7 +2709,7 @@ def forward( k.shape[1], max_seqlen_kv_, k.device ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + 
# [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( @@ -3106,17 +2845,17 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) @@ -3157,8 +2896,8 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], @@ -3166,7 +2905,7 @@ def backward(ctx, dout): ) max_seqlen_kv = seq_end_idx - seq_start_idx k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, 
v_]] out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) @@ -3239,7 +2978,7 @@ def backward(ctx, dout): dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ x.movedim(seq_dim, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] @@ -3258,13 +2997,13 @@ def backward(ctx, dout): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) @@ -3335,6 +3074,7 @@ def forward( use_flash_attn_3, softmax_type, softmax_offset, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3342,7 +3082,6 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3406,32 +3145,37 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." 
+ is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + q_fp8, k_fp8, v_fp8 = (None, None, None) if fp8: if use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + else: + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: assert False, "FP8 is only supported with Fused Attention!" 
else: @@ -3448,24 +3192,18 @@ def forward( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] - + out_fp8 = None + out_f16 = None batch_size = q.shape[batch_dim] + q_part, k_part, v_part = q, k, v + out_part = None if use_fused_attention: - q_part, k_part, v_part = q, k, v if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=True - ) - out, aux_ctx_tensors = fused_attn_fwd( + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3474,7 +3212,7 @@ def forward( q_part, k_part, v_part, - qkv_dtype, + fwd_nominal_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -3489,8 +3227,24 @@ def forward( softmax_type=softmax_type, softmax_offset=softmax_offset, ) - if fp8: - out = out._data + if isinstance(out_, Float8Tensor): + out_fp8 = out_ + out_ = out_._data + if is_bwd_fp8 and not ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): + out_part = out_fp8 + else: + out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ + out_part = out_ + if ( + fp8 + and is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_part = O_quantizer(out_) else: fa_forward_args_thd = get_fa_args( True, @@ -3502,67 +3256,67 @@ def forward( max_seqlen_kv=max_seqlen_kv, ) fa_outputs = flash_attn_fwd( - q, - k, - v, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, 
**fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: - out, softmax_lse = fa_outputs[4], fa_outputs[5] + out_, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not use_flash_attn_3 else None else: - out, softmax_lse = fa_outputs[0], fa_outputs[1] + out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] + out_part = out_ - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) - out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) + out_ = flash_attn_a2a_communicate( + out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) + # [b*s, h, d] -> [b, s, h, d] + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + # [s*b, h, d] -> [s, b, h, d] + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if fp8: - if is_output_fp8: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False - ) - out_ret = out_fp8 - out = out_fp8._data - else: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=True - ) - out_f16 = out_fp8.dequantize(dtype=qkv_dtype) - out_ret = out_f16 + if fp8 and use_fused_attention: + if fp8_recipe.float8_current_scaling(): + out_f16 = out_ + if is_output_fp8: + out_fp8 = O_quantizer(out_) + if fp8_recipe.delayed(): + out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) + if not is_output_fp8: + out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) else: - out_ret = out + out_f16 = out_ - if not fp8 or 
int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - else: - if is_input_fp8: - q_save, k_save, v_save = q, k, v - else: - q_save, k_save, v_save = q_f16, k_f16, v_f16 - if is_output_fp8: - out_save = out + out_ret = out_fp8 if is_output_fp8 else out_f16 + + ctx.fp8 = fp8 and is_bwd_fp8 + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) else: - out_save = out_f16 + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + f16_tensors = (q_part, k_part, v_part, out_part) + else: + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - k_save, - v_save, - out_save, + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -3571,6 +3325,7 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.out_shape = out_ret.shape ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -3585,14 +3340,14 @@ def forward( ctx.deterministic = deterministic ctx.window_size = window_size ctx.use_fused_attention = use_fused_attention - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 ctx.softmax_type = softmax_type - ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -3616,6 +3371,10 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -3626,23 +3385,21 @@ def 
backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") - dout_dtype = dout.dtype + bwd_nominal_dtype = ctx.fwd_nominal_dtype + dqkv_te_dtype = None fused_attn_backend = None - fused_attn_dqkv_dtype = None + dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: + if not isinstance(dout, QuantizedTensorBase): dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dout_fp8 = dout + dqkv_te_dtype = dout._fp8_dtype dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3652,44 +3409,23 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None: - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- ctx.dO_quantizer = dout._quantizer - dout = dout._data - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - k = ctx.QKV_quantizer.create_tensor_from_data( - k, fake_dtype=ctx.qkv_dtype, internal=True - ) - v = ctx.QKV_quantizer.create_tensor_from_data( - v, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) + dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) + else: + dout = dout.view(*ctx.out_shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) - out, dout = flash_attn_a2a_communicate( - [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - out = ctx.O_quantizer.create_tensor_from_data( - out, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - out = out.dequantize(dtype=ctx.qkv_dtype) - dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -3730,30 +3466,14 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: - q_part = q - k_part = k - v_part = 
v - out_part = out - dout_part = dout - + q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - + q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_part = out + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3764,8 +3484,8 @@ def backward(ctx, dout): v_part, out_part, dout_part, - dout_dtype, - fused_attn_dqkv_dtype, + bwd_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, @@ -3780,10 +3500,9 @@ def backward(ctx, dout): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if ctx.fp8: - dq = dq._data - dk = dk._data - dv = dv._data + if isinstance(dq, Float8Tensor): + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -3813,7 +3532,7 @@ def backward(ctx, dout): **fa_backward_kwargs, ) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -3835,17 +3554,22 @@ def 
backward(ctx, dout): ) if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data( - dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dk = ctx.dQKV_quantizer.create_tensor_from_data( - dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dv = ctx.dQKV_quantizer.create_tensor_from_data( - dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ctx.fp8_recipe.delayed(): + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] + if not ctx.is_input_fp8: + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + ) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3876,6 +3600,7 @@ def backward(ctx, dout): None, None, d_softmax_offset, + None, ) @@ -3910,6 +3635,8 @@ def attn_forward_func_with_cp( use_flash_attn_3=False, softmax_type="vanilla", softmax_offset=None, + fp8_output=False, + layer_number=1, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3973,10 +3700,15 @@ def attn_forward_func_with_cp( """ if cp_comm_type == "a2a+p2p": - assert isinstance( - cp_group, list - ), "Hierarchical CP implementation needs multi-level CP groups!" - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + isinstance(cp_group, list) and len(cp_group) == 2 + ), "CP implementation a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" 
+ assert ( + attn_bias_type == "no_bias" + ), f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] cp_comm_type = "p2p" @@ -4064,6 +3796,8 @@ def attn_forward_func_with_cp( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -4082,6 +3816,7 @@ def attn_forward_func_with_cp( use_flash_attn_3, softmax_type, softmax_offset, + fp8_output, ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f72cd69262..a19d08ae59 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -14,8 +14,22 @@ from torch.nn.parameter import Parameter import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Format, + Recipe, + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.fp8 import ( + get_fp8_te_dtype, + FP8GlobalStateManager, + RecipeState, + DelayedScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, +) from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -73,6 +87,67 @@ "_alibi_bias_require_update": False, } +""" +This feature is **experimental** and subject to change. 
+ +Some models may use different FP8 recipes for their linear layers and attention layers. To support this, +users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer, +or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention +layers as follows. + ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| Linear | Attention | Configuration | ++===================+===========+===================================================================================+ +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); | +| | | export NVTE_DPA_FP8_RECIPE="F16" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | 
++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); | +| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| +| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); | +| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | +| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); | +| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | +| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +""" +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} +_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")] +_dpa_fp8ds_amax_algo = 
os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent") +_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")) +_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" + + __all__ = ["DotProductAttention"] @@ -462,6 +537,231 @@ def set_context_parallel_group( self.cp_stream = cp_stream self.cp_comm_type = cp_comm_type + def init_fp8_metadata(self, num_gemms: int = 1) -> None: + """ + Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support. + Initialize fp8 related metadata and tensors during fprop. + """ + _original_recipe = self.fp8_meta.get("recipe", None) + + # global recipe set in fp8_autocast() + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + + # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to + # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. + # + # fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers + # -------------------------------------------------------------------------------------------- + # DelayedScaling (DS) | unset | DS | all DS + # Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP + # x={DS, CS} | y | refer to row x=y | refer to row x=y + fp8_recipe_dpa = fp8_recipe + fp8_recipes = fp8_recipe + if _dpa_fp8_recipe == "F16": + # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False + fp8_recipe.fp8_dpa = False + fp8_recipe.fp8_mha = False + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes 
= fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + fp8_recipe, + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in ( + "", + "Float8CurrentScaling", + ): + # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format + # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP + fake_recipes = [ + Float8CurrentScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + 
reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False + if not fp8_recipe_dpa.float8_per_tensor_scaling(): + assert not ( + fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha + ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + + # reduce over TP+CP groups; expect fp8_group to be set up so + # assume attention uses the same fp8_group as GEMMs + fp8_group = FP8GlobalStateManager.get_fp8_group() + + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self.fp8 = FP8GlobalStateManager.is_fp8_enabled() + self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration + self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters or fp8_enabled: + self.fp8_meta["global_recipe"] = fp8_recipe + self.fp8_meta["local_recipes"] = ( + fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes] + ) + + if self.fp8_parameters or fp8_enabled: + if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]: + # FP8 init has already been run and recipe is the same, don't do anything. + return + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + else: + # If fp8 isn't enabled, turn off and return. 
+ self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(fp8_recipes) + + if fp8_enabled: + # Set FP8 and other FP8 metadata + self.fp8_meta["num_gemms"] = num_gemms + self.fp8_meta["fp8_group"] = fp8_group + + # Set FP8_MAX per tensor according to recipe + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + + # Allocate scales and amaxes + self.init_fp8_meta_tensors(fp8_recipes) + self.fp8_initialized = True + + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + + def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: + """Override to allow multiple recipes. 
Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe): + recipe = [recipe] + fp8_recipe_dpa = recipe[-1] + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + # Return early if recipe state matches recipe + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd) + return + if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if fp8_recipe_dpa.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return + if fp8_recipe_dpa.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return + + # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers. + # See table above in init_fp8_metadata for more detail. + num_gemms = [2, 1] if len(recipe) == 2 else [3] + # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and + # 2 (grad_output and grad_input) for bwd + num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms] + + # Initialize recipe state and quantizers + recipe_states = [ + RecipeState.create( + recipe[i], + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors[i], + ) + for i in range(len(recipe)) + ] + + self.fp8_meta[fp8_meta_tensor_key] = ( + recipe_states[-1] if len(recipe) == 2 else recipe_states[0] + ) + self.quantizers[fp8_meta_tensor_key] = [] + for recipe_state in recipe_states: + self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + @no_torch_dynamo(recursive=False) def forward( self, @@ -485,6 +785,7 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -657,6 +958,8 @@ def forward( pad_between_seqs: Optional[bool], default = `None` If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If true, there are padding tokens between individual sequences in a packed batch. + fp8_output: Optional[bool], default = `False` + Whether to enforce output to be in FP8 or not. 
""" with torch.cuda.device(query_layer.device), self.prepare_forward( @@ -693,6 +996,8 @@ def forward( tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + else: + fp8_output = False # checks for q/k/v shapes assert ( @@ -1092,6 +1397,7 @@ def forward( quantizers=self.quantizers, inference_params=inference_params, flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, ) if use_fused_attention: @@ -1140,6 +1446,7 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + fp8_output=fp8_output, ) return self.fused_attention( query_layer, @@ -1169,6 +1476,7 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + fp8_output=fp8_output, ) from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled @@ -1180,6 +1488,7 @@ def forward( ) if use_unfused_attention: + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -1198,6 +1507,10 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return self.unfused_attention( _alibi_cache, @@ -1215,5 +1528,9 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 72c595e3ff..ea7b0e8763 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -17,6 +17,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.distributed as dist import torch.nn.functional as F import transformer_engine_torch as tex import transformer_engine as te @@ -32,11 +33,13 @@ META_DO, META_S, META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -44,6 +47,8 @@ from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, + SplitAlongDim, + combine_tensors, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -54,6 +59,9 @@ # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +# print quantizer info for a particular layer on a particular rank +_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1")) +_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0")) _cu_seqlens_cache = {} @@ -350,8 +358,31 @@ def get_attention_backend( field.name: getattr(attention_params, field.name) for field in fields(attention_params) } run_config.update(attention_params_dict) + # Add FP8 environment variables to config if fp8: + # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd + run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1")) + # 
switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling" + _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") + run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe + if _dpa_fp8_recipe != "": + # config new recipe if switched + run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID") + run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv( + "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent" + ) + run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int( + os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1") + ) + run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int( + os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") + ) + # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow + run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") + ) logger.debug("Running with config=%s", run_config) # The following sections check if `FlashAttention` supports the provided attention params, @@ -431,8 +462,20 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") - use_unfused_attention = False + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + if not allow_emulation: + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + use_unfused_attention = False + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if ( + use_fused_attention + and fp8_recipe.float8_current_scaling() + and device_compute_capability < (10, 0) + ): + logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") + use_fused_attention = False # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size @@ -1875,11 +1918,10 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers, 
cp_specific_quantizers=False): +def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: - num_of_nones = 8 if cp_specific_quantizers else 6 - return [None] * num_of_nones + return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1888,6 +1930,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True dQKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1897,22 +1940,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] - dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_CP_quantizer.internal = True - O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] - O_CP_quantizer.set_usage(rowwise=True, columnwise=False) - - if cp_specific_quantizers: - return ( + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def print_quantizers( + label, + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, +): + """Print the type and scale/amax of attention quantizers""" + _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2 + if ( + _to_print + and _print_layer == layer_number + and ( + not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank) + ) + ): + names = [ + "QKV_quantizer", + "S_quantizer", + "O_quantizer", + "dO_quantizer", + "dP_quantizer", + "dQKV_quantizer", + ] + quantizers = [ 
QKV_quantizer, - O_quantizer, - O_CP_quantizer, S_quantizer, - dQKV_quantizer, - dQKV_CP_quantizer, + O_quantizer, dO_quantizer, dP_quantizer, - ) + dQKV_quantizer, + ] + if "forward" in label: + names = names[:3] + quantizers = quantizers[:3] + if "backward" in label: + names = names[3:] + quantizers = quantizers[3:] + for i, q in enumerate(quantizers): + type_str = "" + if q is None: + type_str = "None" + elif isinstance(q, Float8Quantizer): + type_str = "DS" + elif isinstance(q, Float8CurrentScalingQuantizer): + type_str = "CS" + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) - return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): + """Combine q,k,v based on qkv_layout and quantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + src_nominal_dtype = q.dtype + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = combine_tensors([q, k, v], dim) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv = combine_tensors([k, v], dim) + tensors = [q, kv] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, kv_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True) + case 3: + tensors = [q, k, v] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = 
[sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype) + for x in [q_data, k_data, v_data] + ] + + return q_fp8, k_fp8, v_fp8 + + +def combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None +): + """Combine q,k,v based on qkv_layout and dequantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + src_nominal_dtype = q_fp8.dtype + else: + assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" 
+ if des_nominal_dtype is None: + des_nominal_dtype = src_nominal_dtype + + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv_data = combine_tensors([q_data, k_data, v_data], dim) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv_data = combine_tensors([k_data, v_data], dim) + tensors = [q_data, kv_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + k, v = SplitAlongDim.apply(kv, dim, [1, 1], True) + case 3: + tensors = [q_data, k_data, v_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + return q, k, v diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 790d78c75e..b2f1ff1ac9 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -3,6 +3,7 @@ # See 
LICENSE for license information. """Multi-head Attention.""" +import os import collections from typing import Callable, List, Optional, Tuple, Union import torch @@ -31,7 +32,13 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + +# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast(). +# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" +# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1" +_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1" class MultiheadAttention(torch.nn.Module): @@ -570,10 +577,12 @@ def set_context_parallel_group( self.cp_size = get_distributed_world_size(cp_group) self.cp_rank = get_distributed_rank(cp_group) elif isinstance(cp_group, list): - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" assert ( cp_comm_type == "a2a+p2p" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + assert ( + len(cp_group) == 2 + ), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" 
cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_size_p2p = get_distributed_world_size(cp_group[1]) @@ -730,10 +739,22 @@ def forward( # Query, Key, and Value # ====================== - fp8_mha = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.get_fp8_recipe().fp8_mha - ) + fp8 = FP8GlobalStateManager.is_fp8_enabled() + if _dpa_fp8_recipe == "": + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_dpa = fp8_recipe.fp8_dpa + fp8_mha = fp8_recipe.fp8_mha + float8_current_scaling = fp8_recipe.float8_current_scaling() + else: + fp8_dpa = _dpa_fp8_recipe_dpa + fp8_mha = _dpa_fp8_recipe_mha + float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling + # DPA: always produce FP8 output when fp8=True to take advantage of the O amax + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + # Proj Gemm: match DPA output except for Float8CurrentScaling + proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None if self.attention_type == "self": @@ -742,7 +763,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -752,7 +773,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) num_queries_per_key_value = ( @@ -806,7 +827,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.qkv_weight_interleaved: @@ 
-861,7 +882,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -871,7 +892,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) # [sq, b, hp] --> [sq, b, np, hn] @@ -972,6 +993,7 @@ def forward( fast_zero_fill=fast_zero_fill, inference_params=inference_params, pad_between_seqs=pad_between_seqs, + fp8_output=dpa_fp8_output, ) # =================== @@ -980,7 +1002,7 @@ def forward( projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, - fp8_grad=isinstance(context_layer, QuantizedTensor), + fp8_grad=proj_fp8_grad, ) if self.return_bias: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index df2f5d1cab..94a12c4a09 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -109,9 +109,6 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 def fused_attn_fwd( diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c94bd0d2a5..978bee52dc 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -201,7 +201,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { * amax to be initialized to zero. 
*/ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype); + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4edc6d81e1..cc33f2a89c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -78,6 +78,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data); + std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 5db9dd73da..344bc4ab0b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -53,6 +53,47 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( return fused_attention_backend; } +// helper function for S and dP quantizers +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data) { + std::unique_ptr T_quantizer = convert_quantizer(quantizer); + TensorWrapper te_T; + py::object py_T; + if (quantizer.is_none()) { + // high precision + auto *none_quantizer = dynamic_cast(T_quantizer.get()); + if (data.has_value()) { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); + } + } 
else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // delayed scaling; this helps initialize scale_inv + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // current scaling + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } + } + return {std::move(te_T), std::move(py_T)}; +} + // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, @@ -66,44 +107,30 @@ std::vector fused_attn_fwd( py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread) { - TensorWrapper te_Q, te_K, te_V, te_O, te_S; - auto none = py::none(); - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); + // create QKV tensor wrappers + TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); - - // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. 
const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + // create S tensor + TensorWrapper te_S; + py::object py_S; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + + // create O tensor + TensorWrapper te_O; + py::object py_O; + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); - std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - // create output tensor O - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; - py::object o_python, s_python; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // Initialize FP8 tensor with scale-inverse - auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - } - auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); // construct NVTE tensors TensorWrapper te_Bias; @@ -114,11 +141,12 @@ std::vector 
fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - te_O.zero_(at::cuda::getCurrentCUDAStream()); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if ((h * d) % block_size == 0) { + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + te_O.zero_(at::cuda::getCurrentCUDAStream()); + } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { @@ -181,7 +209,8 @@ std::vector fused_attn_fwd( auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options); philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); @@ -210,7 +239,7 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; - output_tensors.push_back(o_python); + output_tensors.push_back(py_O); auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), @@ -280,50 +309,44 @@ std::vector fused_attn_bwd( const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer) { auto none = py::none(); - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + + // create 
QKV, O, dO tensor wrappers + TensorWrapper te_Q, te_K, te_V, te_O, te_dO; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); te_O = makeTransformerEngineTensor(O, none); te_dO = makeTransformerEngineTensor(dO, none); - // qkv type from the te_Q - std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); - const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - py::object s_python, dp_python; - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); - } + // create S and dP tensors + TensorWrapper te_S, te_dP; + py::object py_S, py_dP; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_dP, py_dP) = + quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); + // create dQ, dK, dV tensors + TensorWrapper te_dQ, te_dK, te_dV; + py::object py_dQ, py_dK, py_dV; + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector 
k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; - auto d_v = v_shape[v_shape.size() - 1]; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = d_v; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + options = options.dtype(torch::kUInt8); + } + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + options = options.dtype(fake_dtype); + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -396,39 +419,27 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_dQ, py_dQ) = - fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); - std::tie(te_dK, py_dK) = - fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); - std::tie(te_dV, py_dV) = - fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); - } else { - auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); - std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, 
py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); - } + + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); // construct NVTE tensors - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } } - - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); 
dK.fill_(0); @@ -605,7 +616,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 int seq_dim = tensor.dim() == 3 ? 0 : 1; - int batch = cu_seqlens.size(0) - 1; int num_heads = tensor.size(seq_dim + 1); int dim_per_head = tensor.size(seq_dim + 2); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); @@ -769,8 +779,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t NVTE_CHECK(world_size > 0); NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); - int batch = cu_seqlens.size(0) - 1; - std::vector shape = {total_tokens / world_size}; at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); @@ -808,7 +816,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, **************************************************************************************************/ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { - int max_seq_len = tensor.size(1); int h = tensor.size(2); int d = tensor.size(3); std::vector shape = {t, h, d}; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index e9647b44fe..2c1edae4c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); - const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Set amax if use_existing_amax = true (only valid for CS) + bool use_existing_amax = false; + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + use_existing_amax = 
quantizer.attr("use_existing_amax").cast(); + if (use_existing_amax) { + const at::Tensor &amax = quantizer.attr("amax").cast(); + input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + } // Initialize output tensor TensorWrapper output_cpp; @@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + if (use_existing_amax) { + auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); + quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); + } else { + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + } return output_py; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 2abe9614e1..8470466aef 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -390,9 +390,13 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype) { + DType dtype, + std::optional data) { amax.zero_(); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto out = data.has_value() ? 
NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); return {std::move(out_cpp), std::move(out_py)}; diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a75a03bfa5..15017913fe 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -970,7 +970,9 @@ def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) + Float8CurrentScalingQuantizer( + self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + ) for i in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 1524584aa7..18750d0392 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -215,6 +215,8 @@ class Float8CurrentScalingQuantizer(Quantizer): amax: torch.Tensor """FP8 datatype""" dtype: TE_DType + """amax update options""" + use_existing_amax: bool """amax reduction options""" with_amax_reduction: bool amax_reduction_group: Optional[dist_group_type] @@ -229,6 +231,7 @@ def __init__( *, rowwise: bool = True, columnwise: bool = True, + use_existing_amax: bool = False, with_amax_reduction: bool = False, amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, @@ -238,6 +241,7 @@ def __init__( self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype + self.use_existing_amax = use_existing_amax self.with_amax_reduction = with_amax_reduction 
self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales From 7fa0f5541bff9df574c7b7c7c6b6cd46e9009b57 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 30 Sep 2025 16:51:43 -0700 Subject: [PATCH 45/78] [Pytorch] Support for Swiglu Activation used in GPT OSS (#2161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Test working as I think it should work Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... 
Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir 
Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for 
sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero 
tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim 
Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed 
NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> Signed-off-by: Varun Thumbe minor bug: quantizer 
should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe fix linting error Signed-off-by: Varun Thumbe * initial draft of changes to get GPT oss based swiglu integrated, gated kernels needs to be fixed Signed-off-by: Varun Thumbe * redundant implementation for the pytorch to te hook up, refactoring to be done later Signed-off-by: Varun Thumbe * all gated kernels modified, pytest working for oss swiglu Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix the merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up 
cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir 
Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] fix cross entropy vanishing gradients (#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 
fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Signed-off-by: Varun Thumbe Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: 
Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: Varun Thumbe [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Signed-off-by: Varun Thumbe a bit of cleanup Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * accidentally had removed some activations, minor bug in the templated function Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262513 +0000 parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262476 +0000 parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 
+0000 committer Varun Thumbe 1758262304 +0000 merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> 
[PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> 
[PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] fix cross entropy vanishing gradients (#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon 
<4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) 
Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Fused RoPE with combined QKV input. (#2122) * Fused RoPE with combined QKV input. Initial commit for Dropout with 8-bit RNG Fix documentation Initial commit for Fused QKV RoPE WIP Initial tests passing Enable rotary percent and margin Enable CP2, start_positions, interleaved Cleanup test Revert "Fix documentation" This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715. Revert "Initial commit for Dropout with 8-bit RNG" This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca. Cleanup. Minor cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernels Signed-off-by: Vasudevan Rengasamy * Misc. 
Cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernel performance Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Move fused_qkv_rope test to test_fused_rope.py Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply shared memory optimization to separate fused rope kernels Signed-off-by: Xin Yao * fix lint Signed-off-by: Xin Yao --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * accidentally removed the copyright Signed-off-by: Varun Thumbe * fix linting issue Signed-off-by: Varun Thumbe * minor issue in comments Signed-off-by: Varun Thumbe * Commit is for another PR Signed-off-by: vthumbe1503 * revert changes since this belongs to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert change back since belongs to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changes belong to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert changes here Signed-off-by: vthumbe1503 Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162) * add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen * added outer_impl Signed-off-by: Phuong Nguyen * clean up FFI Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell * assert line change Signed-off-by: Jonathan Mitchell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan Mitchell Co-authored-by: Jonathan Mitchell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh * address review comments Signed-off-by: Varun Thumbe * cleanup Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linting error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [PyTorch Debug] Fix issue with negative underflow% stat. 
(#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Address review comments, fix mxfp8 kernel bug: was not passing clamped swiglu parameter correctly Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as Norm outputs' and activation compuations. Signed-off-by: Ming Huang * Adding knob to control norm output precision. Signed-off-by: Ming Huang * Removing the knob and applying lower-precision norm with current-scaling only. Signed-off-by: Ming Huang * Fix the error when quantizer==None Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang [PyTorch] Support activation CPU offloading in fusible ops (#2158) * Add CPU offloading logic to ops. Fix test to compute dgrad. 
Signed-off-by: Tim Moon * Make sure grads are contiguous in op backwards Signed-off-by: Tim Moon * Add op-based MLP to CPU offloading tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle different weight cache behavior on Hopper/Blackwell Add MXFP8 to CPU offload tests. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove MXFP8 test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174) * Do not use norm fwd + amax fusion if cudnn backend is requested Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Read envirornment vairable directly to avoid include error Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix unjoined comm stream in UB communicator (#2160) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> FP8 Output Quantization for GEMM (#2123) * Test working as I think it should work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the 
code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed 
NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for 
sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero 
tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim 
Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed 
NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. 
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes 
from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani feat: Add support for multiple quantization modes in the UB communicators (#2043) [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. 
Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for 
dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> fix linting error Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> address review comments Signed-off-by: Varun Thumbe * Update test_multi_process_distributed_grouped_gemm.py change accidentally added while merging Signed-off-by: vthumbe1503 * Update dense.py change accidentally added while merging Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review 
comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address revie comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bug solved: delayed scaling quantization with mxfp8 inputs didnt work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the unit test error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * just to trigger ci Signed-off-by: Varun Thumbe * address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> TE Gemma tutorial attempt#2 (#1839) * add tutorial files and other local changes Signed-off-by: Sudhakar Singh * 
remove extraneous code for easy debu Signed-off-by: Sudhakar Singh * make cuda graphs work with non-paged and paged attention Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * perf imp for kv cache ops Signed-off-by: Sudhakar Singh * add code for calibration Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changes to make quantizers work with fp8_calibration Signed-off-by: Sudhakar Singh * avoid reindexing from python side Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename variable from previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use quantizer only if needed Signed-off-by: Sudhakar Singh * functionality of the tutorial tested and perf checked Signed-off-by: Sudhakar Singh * remove files and update headers/licenses Signed-off-by: Sudhakar Singh * update header/license Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tutorial for review Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make weights downloadable on the fly; remove extra print statements Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint and update comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comma back, typo Signed-off-by: Sudhakar Singh * sequence_start_positions should be None for training Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add paged attention numberes and update requirements.txt file Signed-off-by: Sudhakar Singh * more fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make tutorial work on blackwell Signed-off-by: Sudhakar Singh * remove gemma FT tutorial for now Signed-off-by: Sudhakar Singh * fixing the headings placement and rewording attention -> kv caching Signed-off-by: Sudhakar Singh * fixes from comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the images Signed-off-by: Sudhakar Singh * misc fixes Signed-off-by: Sudhakar Singh * add more comments to te_gemma.py and cleanup utils.py Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more information about the hierarchy of the classes used in the tutorial Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add better cuda graphs picture Signed-off-by: Sudhakar Singh * addd updated cuda graphs pictures Signed-off-by: Sudhakar Singh * add illustrated cuda graphs Signed-off-by: Sudhakar Singh * fix Signed-off-by: Sudhakar Singh * small fixes in documentation Signed-off-by: Sudhakar Singh * add torch.no_grad() to force reduced memory usage Signed-off-by: Sudhakar Singh * some fixes from recent comments Signed-off-by: Sudhakar Singh * more fixes from remaining comments Signed-off-by: Sudhakar Singh * add te_rope_emb to class desc Signed-off-by: Sudhakar Singh * 
fix tutorial wording; add calibration fix to grouped_linear.py Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Fix memory overhead of linear layer when all gather from sequence parallel (#2125) * fix memory overhead of all gather from sequence parallel Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * quick fix the errors when for UB buffers Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --------- Signed-off-by: Yuzhong Wang Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix incorrect TP rank calculation when using data parallel (#2179) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045) * feat: add cutlass group gemm support Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor multi tensor gemm interface Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for 
more information, see https://pre-commit.ci * refactor: refactor nvte_multi_stream_cublas_gemm func and add license info Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add unit test for cutlass group gemm Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add cutlass support type protect Signed-off-by: Min Yang * add tests and fix lint Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: fix unit tests error Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: refactor host workspace malloc Signed-off-by: Min Yang * update cutlass Signed-off-by: Xin Yao * update cutlass Signed-off-by: Xin Yao * further relex threshold and add a env var to warn fall back Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Min Yang Signed-off-by: Xin Yao Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com> Co-authored-by: Min Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen [PyTorch] Support FA3 for MLA and with CP (#1907) feature(FA3,MLA,CP): 1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward 2. Update get_attention_backend method because FA3 support MLA now 3. Add CP MLA support for FA3 4. Add unit tests for FA3 MLA CP 5. Update attention doc Signed-off-by: zhujian Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185) * Fix cudnn version checks for kv cache for sm89. 
Add cudnn version check in preparation for 9.14 when getting backend Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix for cuDNN version condition check Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use limit=0.75 in clamped SwiGLU test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * accidentally removed a line while resolving merge conflict Signed-off-by: Varun Thumbe * match pytorch implementation: dclamp should be 1 for borders of clamping limits as well Signed-off-by: Varun Thumbe * fix dswiglu quantization fusion bug Signed-off-by: Varun Thumbe * pass param by reference as much as possible Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * float should rather be bool: fix by copilot Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * { missed in activation.cpp Signed-off-by: Varun Thumbe * minor formatting change Signed-off-by: Varun Thumbe * nvfp4 change Signed-off-by: Varun Thumbe --------- Signed-off-by: Varun Thumbe Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Signed-off-by: vthumbe1503 Signed-off-by: Jonathan Mitchell Signed-off-by: Pawel Gadzinski Signed-off-by: Kshitij Lakhani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Jonathan Mitchell Co-authored-by: Sudhakar Singh --- tests/pytorch/test_fusible_ops.py | 74 +++++++++ .../common/activation/activation_template.h | 10 +- transformer_engine/common/activation/gelu.cu | 12 +- transformer_engine/common/activation/relu.cu | 12 +- .../common/activation/swiglu.cu | 23 ++- .../include/transformer_engine/activation.h | 40 +++++ .../common/util/cast_gated_kernels.cuh | 146 +++++++++++------- transformer_engine/common/util/math.h | 39 ++++- .../common/util/vectorized_pointwise.h | 26 +++- transformer_engine/pytorch/csrc/extensions.h | 4 + .../pytorch/csrc/extensions/activation.cpp | 139 ++++++++++++----- .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/ops/basic/__init__.py | 14 +- .../pytorch/ops/basic/activation.py | 36 +++++ 14 files changed, 459 insertions(+), 122 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 4409866617..231fa64bc1 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1736,6 +1736,80 @@ def test_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) + def test_clamped_swiglu( + self, + *, + out_shape: Iterable[int] = (32, 32), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, + limit: float = 0.75, + alpha: float = 1.702, + ): + # Test SwiGLU variant used in GPT OSS. 
+ # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x_glu, x_linear = x_ref.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y_ref = out_glu * (x_linear + 1) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantize_backward), + te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), + te_ops.Quantize(forward=quantize_forward, backward=False), + ) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute and quantization == "nvfp4": + tols = dtype_tols(tex.DType.kFloat4E2M1) + elif quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("dtype", _dtypes) diff --git 
a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67f173a4ab..1d9a3fb43c 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, } template -void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = false; constexpr NVTETensor grad = nullptr; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = true; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 0cf43007a7..4949ba5906 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>(grad, 
input, output, stream); + Empty e = {}; + dgated_act_fn, dgelu>(grad, input, output, e, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dqgelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index a794b7315f..c74fc6eee9 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, drelu>(grad, input, output, e, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 
NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsrelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 8194964745..cafc48abba 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsilu>(grad, input, output, e, stream); +} + +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + gated_act_fn>(input, output, param, stream); +} + +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + dgated_act_fn, clamped_dsilu>( + grad, input, output, param, 
stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 49029ed588..e50d71040d 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input used in GPT OSS. + * + * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This Gated activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
@@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS. + * + * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 6093b54b6d..ca37a28319 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -55,7 +55,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_act, const __grid_constant__ CUtensorMap tensor_map_output_gate, float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; @@ -161,7 +162,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; OType *out_act_sh_curr = out_act_sh + buff * buff_elems; OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; @@ -171,6 +171,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); @@ -178,18 +184,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr 
(std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); @@ -197,7 +212,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) amax = fmaxf(amax, fabsf(after_dact)); amax = fmaxf(amax, fabsf(after_dgate)); } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; + const float after_act = ActOP(act_elt, p) * gate_elt; out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); amax = fmaxf(amax, fabsf(after_act)); } @@ -300,7 +315,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -476,25 +491,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation 
and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? 
act_x * grad_elt : 0.0f; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { @@ -720,27 +747,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - + bool dgate_elt = true; + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? 
act_x * grad_elt : 0.0f; after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; after_act_rowwise[j] = after_act_elt; } @@ -885,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -960,15 +999,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1099,7 +1137,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1116,7 +1155,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - 
scale_stride_colwise); + + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: @@ -1125,7 +1165,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out OType, true, true, THREADS_PER_CHUNK_NON_COLWISE>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel <<>>( @@ -1133,7 +1172,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; }); // NOLINT(*) @@ -1141,12 +1180,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); NVTE_CHECK(input.flat_last_dim() % 2 == 0, "Wrong input shape. 
Expected (after flattening) last dimension to be even, ", "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); @@ -1168,7 +1204,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); + output->flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1177,7 +1213,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, + cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); @@ -1206,7 +1243,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); + grad.flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1215,7 +1252,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1255,17 +1292,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { if 
(use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); + cast_fp8_gated(grad, gated_input, output, p, stream); } else { if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); + cast_dgated(grad, gated_input, output, p, stream); } else { - cast_gated(gated_input, output, stream); + cast_gated(gated_input, output, p, stream); } } } else if (is_mxfp8_scaling(output->scaling_mode)) { if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); + cast_mxfp8_gated(grad, gated_input, output, p, stream); } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); @@ -1281,7 +1318,7 @@ namespace detail { template void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - cudaStream_t stream) { + ParamOP &p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? 
*(convertNVTETensorCheck(grad)) : grad_empty_tensor; @@ -1290,13 +1327,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, if (is_supported_by_CC_100()) { quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); + output_tensor, p, stream); } else { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, + stream); } else { - cast_gated(gated_input_tensor, output_tensor, stream); + cast_gated(gated_input_tensor, output_tensor, p, stream); } } else { // MX scaling diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2d425d6753..2f20817fb0 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,6 +11,11 @@ namespace transformer_engine { struct Empty {}; +struct ClampedSwiGLUParam { + float limit; + float alpha = 1.702f; // Default value for QuickGELU +}; + template __device__ inline OType gelu(const IType val, const Empty&) { const float cval = val; @@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) { return s * (1.f - s); } +template +__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) { + const float cval = val; + Empty e = {}; + return cval * sigmoid(alpha * cval, e); +} + template __device__ inline OType qgelu(const IType val, const Empty& e) { + return qgelu_with_alpha(val, 1.702f); +} + +template +__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) { const float cval = val; - return cval * sigmoid(1.702f * cval, e); + Empty e = {}; + return alpha * cval * dsigmoid(alpha * cval, e) + + sigmoid(alpha * cval, e); } template __device__ inline OType dqgelu(const IType val, const Empty& e) { - const float cval = val; - return 1.702f * cval * dsigmoid(1.702f * cval, e) + - 
sigmoid(1.702f * cval, e); + return dqgelu_with_alpha(val, 1.702f); } template @@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } +template +__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) { + const float cval = min(p.limit, static_cast(val)); // Clamping + return qgelu_with_alpha(cval, p.alpha); +} + template __device__ inline OType dsilu(const IType val, const Empty& e) { const float cval = val; return cval * dsigmoid(cval, e) + sigmoid(cval, e); } +template +__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) { + const bool dclamp_val = static_cast(val) <= p.limit; + const float clamp_val = min(static_cast(val), p.limit); + const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); + return dclamp_val ? dsilu_val : 0.0f; +} + template __device__ inline OType relu(IType value, const Empty&) { return fmaxf(value, 0.f); diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 0d667a0ece..dd6869e027 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -11,7 +11,7 @@ #include "../common.h" #include "../utils.cuh" - +#include "math.h" namespace transformer_engine { /* \brief Helper class that enables storing multiple values of type DType @@ -338,7 +338,7 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param params, cudaStream_t stream) { + const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -372,7 +372,7 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - fp32 *scale_inv, const size_t N, const 
Param params, + fp32 *scale_inv, const size_t N, const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -431,7 +431,13 @@ __launch_bounds__(unary_kernel_threads) __global__ #pragma unroll for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader0.separate()[i]); - const ComputeType val2 = static_cast(loader1.separate()[i]); + ComputeType val2 = static_cast(loader1.separate()[i]); + + if constexpr (std::is_same::value) { + // Clamp the gated value and add 1 at the end + ComputeType limit = p.limit; + val2 = std::min(std::max(-limit, val2), limit) + 1; + } ComputeType temp = static_cast(Activation(val, p) * val2); if (requires_amax) { __builtin_assume(max >= 0); @@ -532,10 +538,18 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType grad_val = static_cast(grad_loader.separate()[i]); const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); - const ComputeType gate_in = static_cast(input_loader1.separate()[i]); + ComputeType gate_in = static_cast(input_loader1.separate()[i]); + bool dgate_in = true; + + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + const ComputeType limit = p.limit; + dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp + gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; - ComputeType after_dgate = grad_val * Activation(gelu_in, p); + ComputeType after_dgate = dgate_in ? 
grad_val * Activation(gelu_in, p) : 0.0f; if (requires_amax) { __builtin_assume(max >= 0); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cc33f2a89c..d86a96959c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -205,6 +205,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); + +py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + float limit, float alpha); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index cdfb4be408..14cc084c0c 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -3,7 +3,6 @@ * * See LICENSE for license information. ************************************************************************/ - #include "../extensions.h" #include "common.h" #include "pybind.h" @@ -12,10 +11,12 @@ namespace transformer_engine { namespace pytorch { namespace { +using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); +using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t), - const at::Tensor& input, py::handle quantizer, - int shape_divisor = 1) { +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, + Args&&... 
args) { init_extension(); // Input tensor @@ -56,14 +57,28 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud // Compute activation in high precision, then quantize { auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); quantizer_cpp->quantize(temp_nvte, out_nvte); } break; case Impl::FULLY_FUSED: // Compute activation directly { - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), out_nvte.data(), stream); + } + }); } break; case Impl::FUSED_ACTIVATION_AMAX_FP8: @@ -73,7 +88,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; @@ -84,7 +106,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud static_cast(quantizer_cpp.get()); // Already checked cast is valid auto [temp_nvte, _] = 
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; @@ -95,10 +124,9 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud return out_py; } -py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, - cudaStream_t), - const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer, Args&&... args) { init_extension(); // Grad output and input tensors @@ -142,8 +170,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen { auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), - at::cuda::getCurrentCUDAStream()); + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } }); quantizer_cpp->quantize(temp_nvte, grad_input_nvte); } @@ -152,7 +184,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen // Compute activation backward directly { NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), 
grad_input_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + } }); } break; @@ -163,8 +200,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; @@ -175,8 +218,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen static_cast(quantizer_cpp.get()); // Already checked cast is valid auto [temp_nvte, _] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; @@ -186,90 +235,98 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen return grad_input_py; } - } // namespace /* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_gelu, input, 
quantizer); + return activation_helper(input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dgelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_geglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dgeglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_qgelu, input, quantizer); + return activation_helper(input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dqgelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_qgeglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dqgeglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } /* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_relu, input, quantizer); + return activation_helper(input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_drelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_reglu, input, quantizer, 2); + return 
activation_helper(input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dreglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_srelu, input, quantizer); + return activation_helper(input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsrelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_sreglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsreglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } - /* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_silu, input, quantizer); + return activation_helper(input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsilu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_swiglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dswiglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); +} + +/* clamped functions */ +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { + return 
activation_helper(input, quantizer, 2, limit, alpha); +} + +py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + float limit, float alpha) { + return dactivation_helper(grad, input, quantizer, limit, alpha); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 98f71f9a7b..382adbfb05 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, + "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -178,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, + "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 2c903675fb..28d49bf7b9 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ 
b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,19 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU +from .activation import ( + GELU, + GEGLU, + QGELU, + QGEGLU, + ReLU, + ReGLU, + SReLU, + SReGLU, + SiLU, + SwiGLU, + ClampedSwiGLU, +) from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 22779b6017..8a754c6382 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -28,6 +28,7 @@ "SReGLU", "SiLU", "SwiGLU", + "ClampedSwiGLU", ] @@ -392,3 +393,38 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dswiglu(*args, **kwargs) + + +class ClampedSwiGLU(_ActivationOperation): + r"""GPT-OSS + Implementation based on `GPT-OSS`__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit: float + The clamp limit. + alpha: float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward pass. 
+ """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__(cache_quantized_input=cache_quantized_input) + self.limit = limit + self.alpha = alpha + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) From ce18bee70fe98f041e91666f1935ffdd524090db Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:53:01 -0700 Subject: [PATCH 46/78] [JAX] Load modules during initialize for Norm and Act primitives (#2219) Load modules during initialize Signed-off-by: Jeremy Berchtold Co-authored-by: JAX Toolbox --- transformer_engine/jax/csrc/extensions.h | 4 ++ .../jax/csrc/extensions/activation.cpp | 58 +++++++++++++++++ transformer_engine/jax/csrc/extensions/ffi.h | 15 +++++ .../jax/csrc/extensions/normalization.cpp | 63 +++++++++++++++++++ .../jax/csrc/extensions/pybind.cpp | 10 ++- 5 files changed, 148 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 92937dd461..2ab95002fa 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -41,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D // Activation XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x); // 
Normalization +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906bb..b2b3db52c8 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, int64_t act_enum, + JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, + act_enum, scaling_mode, is_2x_int); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x) { @@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("is_2x") .Attr("is_dbias"), FFI_CudaGraph_Traits); + +Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type 
input_buf, + Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + bool is_2x, bool is_dbias) { + return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, + act_input_buf, scale_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, + DActLuDBiasQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // dbias + .Ret() // wkspace + .Attr("scaling_mode") + .Attr("act_enum") + .Attr("is_2x") + .Attr("is_dbias")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 852a67c6cb..82f062a15b 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream; using Dictionary = xla::ffi::Dictionary; constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; +constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); @@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) { } } +template +Error_Type wrapInStreamCapture(std::function func, + cudaStream_t stream, Args... 
args) { + cudaGraph_t graph{}; + NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); + + Error_Type error = func(stream, std::forward(args)...); + + NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph)); + NVTE_CHECK_CUDA(cudaGraphDestroy(graph)); + + return error; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index c35bc6668e..5238193922 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type mu_buf, + Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, + bool zero_centered_gamma, double epsilon, int64_t sm_margin, + JAXX_Scaling_Mode scaling_mode, bool is_2x) { + return wrapInStreamCapture( + std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // scale + .Arg() // gamma + .Arg() // beta + .Ret() // output + .Ret() // colwise_output + .Ret() // scale_inv + .Ret() // colwise_scale_inv + .Ret() // amax + .Ret() // mu + .Ret() // rsigma + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("epsilon") + 
.Attr("sm_margin") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, bool zero_centered_gamma, int sm_margin) { @@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI, .Attr("sm_margin"), FFI_CudaGraph_Traits); +Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, + Buffer_Type gamma_buf, Result_Type xgrad_buf, + Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, int64_t norm_type, + bool zero_centered_gamma, int64_t sm_margin) { + return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf, + rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf, + norm_type, zero_centered_gamma, sm_margin); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("sm_margin")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 06e2e2e005..36dd8205bf 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -22,8 +22,12 @@ pybind11::dict Registrations() { pybind11::dict dict; // Activation - dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); - dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); + dict["te_act_lu_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ActLuHandler)); 
+ dict["te_dact_dbias_quantize_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler)); // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); @@ -44,9 +48,11 @@ pybind11::dict Registrations() { // Normalization dict["te_norm_forward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler)); dict["te_norm_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); // Attention From 7022d50fe1e95eafa7771b74029ea0e3bac9b6d7 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 1 Oct 2025 10:10:48 +0200 Subject: [PATCH 47/78] [PyTorch] Quantizer as API (#2039) * Introduce QuantizerBase Signed-off-by: Evgeny * Expose as a first-class API Signed-off-by: Evgeny * Undo QuantizerBase Signed-off-by: Evgeny * Make Quantizer a base class without implementations Signed-off-by: Evgeny * Support CustomRecipe and CustomRecipeState Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolving comments: quantize impl, num_quantizers, defaults Signed-off-by: Evgeny * Quantizer factories Signed-off-by: Evgeny * Add tests Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * QuantizedTensorBase _get_quantizer() + quantize_() Signed-off-by: Evgeny * Experimental note + LayerNormMLP fix Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor._internal -> 
tensor.base Signed-off-by: Evgeny * Expose Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor import fix Signed-off-by: Evgeny * Single quantizer factory with roles Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More context for qfactory, fwd/bwd_roles Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor Signed-off-by: Evgeny * Rename *Base -> *Storage quantized tensors Signed-off-by: Evgeny * make_quantizers() will take roles from the operation Signed-off-by: Evgeny * Improve tests and fix missing imports Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Merge main followup Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny Signed-off-by: Evgeny Tsykunov Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_custom_recipe.py | 290 ++++++++++++++++++ transformer_engine/common/recipe/__init__.py | 42 ++- .../debug/features/log_tensor_stats.py | 6 +- .../debug/pytorch/debug_quantization.py | 4 +- transformer_engine/pytorch/__init__.py | 15 + .../pytorch/cpp_extensions/gemm.py | 6 +- transformer_engine/pytorch/cpu_offload.py | 8 +- .../pytorch/csrc/extensions/cast.cpp | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 36 +-- transformer_engine/pytorch/csrc/pybind.h | 16 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +- 
transformer_engine/pytorch/distributed.py | 80 ++--- .../pytorch/experimental/quantization.py | 6 +- transformer_engine/pytorch/fp8.py | 56 ++++ transformer_engine/pytorch/module/base.py | 60 ++-- .../pytorch/module/grouped_linear.py | 17 +- .../pytorch/module/layernorm_linear.py | 28 +- .../pytorch/module/layernorm_mlp.py | 56 ++-- transformer_engine/pytorch/module/linear.py | 30 +- transformer_engine/pytorch/ops/_common.py | 8 +- .../pytorch/ops/basic/basic_linear.py | 4 +- .../pytorch/ops/basic/dropout.py | 4 +- .../ops/fused/userbuffers_forward_linear.py | 4 +- transformer_engine/pytorch/tensor/__init__.py | 50 ++- .../pytorch/tensor/_internal/__init__.py | 4 - .../pytorch/tensor/float8_blockwise_tensor.py | 28 +- .../pytorch/tensor/float8_tensor.py | 38 ++- .../pytorch/tensor/mxfp8_tensor.py | 35 +-- .../pytorch/tensor/nvfp4_tensor.py | 10 +- .../pytorch/tensor/quantized_tensor.py | 90 ++++-- .../pytorch/tensor/storage/__init__.py | 9 + .../float8_blockwise_tensor_storage.py} | 10 +- .../float8_tensor_storage.py} | 14 +- .../mxfp8_tensor_storage.py} | 14 +- .../nvfp4_tensor_storage.py} | 12 +- transformer_engine/pytorch/tensor/utils.py | 8 +- transformer_engine/pytorch/utils.py | 8 +- 37 files changed, 808 insertions(+), 312 deletions(-) create mode 100644 tests/pytorch/test_custom_recipe.py delete mode 100644 transformer_engine/pytorch/tensor/_internal/__init__.py create mode 100644 transformer_engine/pytorch/tensor/storage/__init__.py rename transformer_engine/pytorch/tensor/{_internal/float8_blockwise_tensor_base.py => storage/float8_blockwise_tensor_storage.py} (98%) rename transformer_engine/pytorch/tensor/{_internal/float8_tensor_base.py => storage/float8_tensor_storage.py} (96%) rename transformer_engine/pytorch/tensor/{_internal/mxfp8_tensor_base.py => storage/mxfp8_tensor_storage.py} (97%) rename transformer_engine/pytorch/tensor/{_internal/nvfp4_tensor_base.py => storage/nvfp4_tensor_storage.py} (98%) diff --git a/tests/pytorch/test_custom_recipe.py 
b/tests/pytorch/test_custom_recipe.py new file mode 100644 index 0000000000..cb840f1971 --- /dev/null +++ b/tests/pytorch/test_custom_recipe.py @@ -0,0 +1,290 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.common import recipe +from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast +from transformer_engine.pytorch import Linear +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear +from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.module.grouped_linear import GroupedLinear + + +@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"]) +def test_custom_recipe_sanity(module_type): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + # Simple linear layer with dims divisible by 16 + in_features = 64 + out_features = 64 + batch = 32 + + if module_type == "Linear": + model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormLinear": + model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormMLP": + # hidden_size == in_features == out_features for simplicity + model = LayerNormMLP( + hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16 + ).cuda() + else: + # OpsLinear path + model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, 
device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Single factory: map roles to quantizers + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Execute with custom recipe + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Basic sanity: gradients exist + assert inp.grad is not None + + +def test_custom_recipe_grouped_linear_sanity(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + num_gemms = 3 + in_features = 64 + out_features = 64 + batch = 32 + base = batch // num_gemms + rem = batch % num_gemms + m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)] + + model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda() + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out = model(inp, m_splits) + loss = out.float().sum() + loss.backward() + + 
assert inp.grad is not None + + +def test_custom_recipe_matches_current_scaling(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(123) + + in_features = 64 + out_features = 64 + batch = 32 + + # Create two identical models + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom.load_state_dict(model_ref.state_dict()) + + # Identical inputs for both paths + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_custom = base_inp.clone().detach().requires_grad_(True) + + # Reference: use Float8CurrentScaling recipe + ref_recipe = recipe.Float8CurrentScaling() + with fp8_autocast(enabled=True, fp8_recipe=ref_recipe): + out_ref = model_ref(inp_ref) + # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) + ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + # Stress dynamic range in grad_output + scale = torch.ones(out_features, device="cuda", dtype=torch.float32) + scale[0] = 1e8 + scale[1] = 1e-8 + loss_ref = (out_ref.float() * scale.view(1, -1)).sum() + loss_ref.backward() + + # Custom: single 
factory returning quantizers per role to match Float8CurrentScaling + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out_custom = model_custom(inp_custom) + # Assert dtypes for custom quantizers match reference mapping + cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + loss_custom = (out_custom.float() * scale.view(1, -1)).sum() + loss_custom.backward() + + # Compare forward outputs (exact match expected) + assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0) + + # Compare input gradients + assert inp_ref.grad is not None and inp_custom.grad is not None + assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0) + + # Compare parameter gradients (weights and bias if present) + ref_params = dict(model_ref.named_parameters()) + custom_params = dict(model_custom.named_parameters()) + for name, p_ref in ref_params.items(): + p_cus = custom_params[name] + assert 
p_ref.grad is not None and p_cus.grad is not None + assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0) + + +def test_custom_recipe_ops_linear_2_1_layout(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(7) + + in_features = 64 + out_features = 64 + batch = 16 + + # Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer + op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None + + +def test_custom_recipe_factory_invocation_counts_and_cycling(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(13) + + in_features = 64 + out_features = 64 + batch = 8 + + op = Linear(in_features, out_features, params_dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Counters per role + counts = { + "linear_input": 0, + "linear_weight": 0, + "linear_output": 0, + "linear_grad_output": 0, + "linear_grad_input": 0, + } + + def quantizer_factory(role): + if role in counts: + counts[role] += 1 + if role in ("linear_input", 
"linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), + # and backward to build 2 quantizers (cycled from 1 factory). + with fp8_autocast(enabled=True, fp8_recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input + assert counts["linear_input"] == 1 + assert counts["linear_weight"] == 1 + assert counts["linear_output"] == 1 + assert counts["linear_grad_output"] == 1 + assert counts["linear_grad_input"] == 1 + + +def test_factories_return_distinct_instances_and_buffers(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + # Two calls should produce distinct quantizer objects and distinct tensor buffers + def factory(): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + q1 = factory() + q2 = factory() + + assert q1 is not q2 + assert q1.scale.data_ptr() != q2.scale.data_ptr() + assert q1.amax.data_ptr() != q2.amax.data_ptr() + + # Mutating one should not affect the other + q1.scale.fill_(123.0) + assert not torch.equal(q1.scale, q2.scale) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 179d618b35..324b5d50c8 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -6,7 +6,8 @@ from __future__ import annotations import 
os from enum import Enum -from typing import Literal, Optional, Union, Callable, NamedTuple +from typing import Any, Literal, Optional, Union, Callable, NamedTuple +from dataclasses import field from pydantic.dataclasses import dataclass @@ -111,6 +112,10 @@ def float8_block_scaling(self): """Whether the given recipe is float8 blockwise scaling.""" return isinstance(self, Float8BlockScaling) + def custom(self): + """Whether the given recipe is custom.""" + return isinstance(self, CustomRecipe) + @dataclass() class DelayedScaling(Recipe): @@ -377,7 +382,6 @@ def __repr__(self) -> str: ) -@dataclass() class NVFP4BlockScaling(Recipe): """ Use the NVFP4 scaling strategy. @@ -456,3 +460,37 @@ def __repr__(self) -> str: f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " ) + + +@dataclass() +class CustomRecipe(Recipe): + """ + Custom recipe that allows users to provide quantizer factories. + + .. warning:: + **EXPERIMENTAL**: Custom recipe is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Parameters + ---------- + qfactory : Callable + Factory callable that returns a quantizer instance for a + given semantic tensor role. + The callable is typically invoked as: + qfactory( + role: str, + ) + + Where `role` is one of the following strings for e.g. 
te.Linear + (stable public contract): + - forward: "linear_input", "linear_weight", "linear_output" + - backward: "linear_grad_output", "linear_grad_input" + """ + + qfactory: Callable[..., Any] + + fp8_dpa: bool = False + fp8_mha: bool = False + + def __repr__(self) -> str: + return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 7ba2f9f771..5d721d9969 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -15,8 +15,8 @@ from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase -from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params @@ -123,7 +123,7 @@ def inspect_tensor( """API call used to collect the data about the tensor before process_tensor()/quantization.""" assert ( - type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase] + type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage] and tensor.dtype != torch.uint8 ), ( f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using" diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 
d564ca8e92..185bf15d05 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, - QuantizedTensorBase, + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) @@ -557,7 +557,7 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): self._update_parent_quantizer_usage() -class DebugQuantizedTensor(QuantizedTensorBase): +class DebugQuantizedTensor(QuantizedTensorStorage): """ Class containing quantized tensors after debug. Depending on configuration it can contain one or two different objects. These objects can be accessed by the method diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 3bdbe4089e..3256512b5c 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -56,6 +56,21 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.tensor import Quantizer +from transformer_engine.pytorch.tensor import Float8Quantizer +from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor import QuantizedTensorStorage +from transformer_engine.pytorch.tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor import MXFP8TensorStorage +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor import Float8Tensor +from transformer_engine.pytorch.tensor import 
MXFP8Tensor +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor import prepare_for_saving +from transformer_engine.pytorch.tensor import restore_from_saved try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d330e023ea..a45fafb68a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -12,7 +12,7 @@ from ..utils import get_sm_count, _empty_tensor from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.utils import is_experimental from ..experimental.gemm import experimental_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -107,9 +107,9 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] - if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage): # There is not use_split_accumulator == False - # implementation for Float8BlockwiseQTensorBase GEMM + # implementation for Float8BlockwiseQTensorStorage GEMM use_split_accumulator = True # Check that data format is supported diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 179c80a656..9378774ea8 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -10,7 +10,7 @@ import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .tensor.quantized_tensor import QuantizedTensorBase +from .tensor.quantized_tensor import 
QuantizedTensorStorage from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] @@ -34,7 +34,7 @@ def mark_activation_offload(*tensors): if tensor is not None: tensor.activation_offloading = True # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorBase classes are saved in the ctx, + # It is needed, because .*TensorStorage classes are saved in the ctx, # and they contain the reference to their data tensors. tensor.needs_force_clear = True @@ -362,7 +362,7 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ), ) - is_quantized_tensor = isinstance(tensor, QuantizedTensorBase) + is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) if not torch_stray_tensor: @@ -514,7 +514,7 @@ def synchronize_on_group_commit_forward(self, current_group): if tensor_tag[0] == self.offloaded_group_count: if hasattr(tensor_buf, "needs_force_clear"): # Need to clear activation tensor - sometimes references persist in the code. - # This is the case for example with the Float8TensorBase class, + # This is the case for example with the Float8TensorStorage class, # which is saved directly inside the ctx while its internal tensors are # saved inside save_for_backward. tensor_buf.data = torch.Tensor() diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c1edae4c6..b6e9ef828c 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -314,7 +314,7 @@ std::tuple, std::vector> bulk_allocate_fp // Construct FP8 block-wise tensors py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? 
py::cast(rowwise_data_list[i]) : py::none(); @@ -461,7 +461,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Construct mxfp8 tensors - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 382adbfb05..3b81393dbd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -23,17 +23,17 @@ namespace transformer_engine::pytorch { PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8TensorStoragePythonClass = nullptr; PyTypeObject *Float8QuantizerClass = nullptr; PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8TensorStoragePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; -PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; -PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { @@ -46,9 +46,9 @@ void init_float8_extension() { Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); auto fp8_base_module = - 
py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); - Float8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage"); + Float8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage")); NVTE_CHECK(Float8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch Float8 extension."); } @@ -61,29 +61,29 @@ void init_mxfp8_extension() { MXFP8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); - MXFP8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage"); + MXFP8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage")); NVTE_CHECK(MXFP8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch MXFP8 extension."); } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorBasePythonClass) return; + if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( - "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + "transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage"); Float8BlockwiseQuantizerClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); - Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorStoragePythonClass = 
reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage")); Float8BlockwiseQTensorPythonClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); - NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); @@ -97,9 +97,9 @@ void init_nvfp4_extensions() { NVFP4TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); auto nvfp4_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); - NVFP4TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage"); + NVFP4TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage")); NVTE_CHECK(NVFP4TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch NVFP4 extension."); } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index f46edaa70e..65665d01b6 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -31,17 +31,17 @@ namespace transformer_engine::pytorch { } while (false); extern PyTypeObject *Float8TensorPythonClass; -extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8TensorStoragePythonClass; extern PyTypeObject *Float8QuantizerClass; extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject 
*MXFP8TensorPythonClass; -extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8TensorStoragePythonClass; extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; -extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; -extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); @@ -55,13 +55,13 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { } inline bool IsFloat8Tensor(PyObject *obj) { - return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass; } inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } inline bool IsMXFP8Tensor(PyObject *obj) { - return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass; } inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { @@ -72,11 +72,11 @@ inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4Quant inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || - Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; + Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass; } inline bool IsNVFP4Tensor(PyObject *obj) { - return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass; } TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); diff --git 
a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8470466aef..42ae658f2a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -152,7 +152,7 @@ std::pair Float8Quantizer::create_tensor( // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -357,7 +357,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -630,7 +630,7 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, @@ -950,7 +950,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle 
MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, @@ -1230,7 +1230,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); out_py = NVFP4TensorClass( "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 3ab0717d0d..c001e8e79a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -41,11 +41,11 @@ from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer -from .tensor._internal.float8_tensor_base import Float8TensorBase -from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase -from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer +from .tensor.storage.float8_tensor_storage import Float8TensorStorage +from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -907,7 +907,7 @@ def 
_all_gather_fp8( async_op: bool = False, quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, -) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]: """All-gather FP8 tensor along first dimension.""" world_size = get_distributed_world_size(process_group) @@ -925,7 +925,7 @@ def _all_gather_fp8( # Cast input tensor to FP8 if needed # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. - if not isinstance(inp, Float8TensorBase): + if not isinstance(inp, Float8TensorStorage): assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer @@ -940,7 +940,7 @@ def _all_gather_fp8( ) # Construct output tensor - out: Float8TensorBase + out: Float8TensorStorage if quantizer is not None: dtype = torch.float32 device = "cuda" @@ -958,7 +958,7 @@ def _all_gather_fp8( out._transpose = None out._transpose_invalid = True else: - raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer") # Assume scaling factors are identical across ranks out._scale_inv = inp._scale_inv @@ -1003,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: def _post_process_fp8_blockwise_gather( - out: Float8BlockwiseQTensorBase, + out: Float8BlockwiseQTensorStorage, quantizer: Float8BlockQuantizer, handle: Optional[torch.distributed.Work] = None, -) -> Float8BlockwiseQTensorBase: +) -> Float8BlockwiseQTensorStorage: """Post-process FP8 blockwise gather.""" if handle is not None: handle.wait() @@ -1040,7 +1040,7 @@ def _post_process_fp8_blockwise_gather( class _FP8BlockwiseAllGatherAsyncHandle: """Handle for asynchronous FP8 blockwise all-gather.""" - tensor: 
Float8BlockwiseQTensorBase + tensor: Float8BlockwiseQTensorStorage quantizer: Float8BlockQuantizer async_handle: torch.distributed.Work _synchronized: bool = False @@ -1078,18 +1078,18 @@ def _all_gather_fp8_blockwise( if isinstance(inp, torch.Tensor): device = inp.device dtype = inp.dtype - elif isinstance(inp, Float8BlockwiseQTensorBase): + elif isinstance(inp, Float8BlockwiseQTensorStorage): if inp._rowwise_data is not None: device = inp._rowwise_data.device elif inp._columnwise_data is not None: device = inp._columnwise_data.device else: - raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data") dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " - f"found {inp.__class__.__name__})" + "Invalid type for input tensor (expected torch.Tensor or" + f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})" ) world_size = get_distributed_world_size(process_group) @@ -1106,7 +1106,7 @@ def _all_gather_fp8_blockwise( # Doing BF16 gather for now as baseline because it's simpler if ( - not isinstance(inp, Float8BlockwiseQTensorBase) + not isinstance(inp, Float8BlockwiseQTensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1131,7 +1131,7 @@ def _all_gather_fp8_blockwise( # Set to compact usage in case the quantizer is not correctly configured orig_all_gather_usage = quantizer.all_gather_usage quantizer.all_gather_usage = True - if not isinstance(inp, Float8BlockwiseQTensorBase): + if not isinstance(inp, Float8BlockwiseQTensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1228,12 +1228,12 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): def _post_process_nvfp4_gather( - 
out: NVFP4TensorBase, + out: NVFP4TensorStorage, columnwise_data_interleaved: torch.Tensor, columnwise_scale_inv_interleaved: torch.Tensor, world_size: int, handle: Optional[torch.distributed.Work] = None, -) -> NVFP4TensorBase: +) -> NVFP4TensorStorage: """Post-process FP8 blockwise gather.""" if handle is not None: handle.wait() @@ -1251,7 +1251,7 @@ def _post_process_nvfp4_gather( class _NVFP4AllGatherAsyncHandle: """Handle for asynchronous NVFP4 all-gather.""" - output: NVFP4TensorBase + output: NVFP4TensorStorage columnwise_data_interleaved: torch.Tensor columnwise_scale_inv_interleaved: torch.Tensor world_size: int @@ -1279,7 +1279,7 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, -) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" # Input tensor attributes @@ -1289,7 +1289,7 @@ def _all_gather_nvfp4( dtype: torch.dtype # Construct packed shapes for input and input_t. - if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase): + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage): # High-precision tensor. 
in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( @@ -1297,7 +1297,7 @@ def _all_gather_nvfp4( ) device = inp.device dtype = inp.dtype - elif isinstance(inp, NVFP4TensorBase): + elif isinstance(inp, NVFP4TensorStorage): if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device @@ -1307,7 +1307,7 @@ def _all_gather_nvfp4( dtype = torch.bfloat16 else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, " + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, " f"found {inp.__class__.__name__})" ) @@ -1321,7 +1321,7 @@ def _all_gather_nvfp4( # For cases where inp has dimensions that cannot be quantized, # we gather in high precision followed by a cast to NVFP4. if ( - not isinstance(inp, NVFP4TensorBase) + not isinstance(inp, NVFP4TensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1336,7 +1336,7 @@ def _all_gather_nvfp4( return out, None # Cast input tensor to NVFP4 with required data - if not isinstance(inp, NVFP4TensorBase): + if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1453,7 +1453,7 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, -) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" # Input tensor attributes @@ -1464,7 +1464,7 @@ def _all_gather_mxfp8( in_shape = inp.size() device = inp.device dtype = inp.dtype - elif isinstance(inp, MXFP8TensorBase): + elif isinstance(inp, MXFP8TensorStorage): if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device @@ -1476,7 +1476,7 @@ 
def _all_gather_mxfp8( dtype = torch.bfloat16 # Guess high-precision dtype. else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " + "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, " f"found {inp.__class__.__name__})" ) @@ -1488,7 +1488,7 @@ def _all_gather_mxfp8( # For cases where inp has dimensions that cannot be quantized, # we gather in high precision followed by a cast to FP8. if ( - not isinstance(inp, MXFP8TensorBase) + not isinstance(inp, MXFP8TensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1503,7 +1503,7 @@ def _all_gather_mxfp8( return out, None # Cast input tensor to MXFP8 with required data - if not isinstance(inp, MXFP8TensorBase): + if not isinstance(inp, MXFP8TensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1587,7 +1587,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensorBase): + if quantizer is not None and not isinstance(inp, QuantizedTensorStorage): inp = quantizer(inp) return inp, None @@ -1634,7 +1634,7 @@ def gather_along_first_dim( out_shape[0] *= world_size # FP8 case: delayed scaling or current scaling - if isinstance(inp, Float8TensorBase) or isinstance( + if isinstance(inp, Float8TensorStorage) or isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): return _all_gather_fp8( @@ -1646,7 +1646,9 @@ def gather_along_first_dim( ) # FP8 block scaling case, block length = 128 - if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance( + quantizer, Float8BlockQuantizer + ): return _all_gather_fp8_blockwise( inp, 
process_group, @@ -1656,7 +1658,7 @@ def gather_along_first_dim( ) # MXFP8 case - if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) return _all_gather_mxfp8( inp, @@ -1667,7 +1669,7 @@ def gather_along_first_dim( ) # NVFP4 case - if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer): + if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): assert isinstance(quantizer, NVFP4Quantizer) return _all_gather_nvfp4( inp, @@ -1683,7 +1685,7 @@ def gather_along_first_dim( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensorBase): + if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1701,7 +1703,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensorBase): + if isinstance(inp, QuantizedTensorStorage): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." 
diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/experimental/quantization.py index 9adf4dabf8..7d573abac8 100644 --- a/transformer_engine/pytorch/experimental/quantization.py +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -13,7 +13,7 @@ import torch from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.experimental import utils @@ -36,7 +36,7 @@ class MMParams: @dataclasses.dataclass -class ExperimentalQuantizedTensor(QuantizedTensorBase): +class ExperimentalQuantizedTensor(QuantizedTensorStorage): """Base class for experimental quantized tensor containers. An experimental container to hold quantization result, including quantized tensor, optional @@ -187,7 +187,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - ) -> QuantizedTensorBase: + ) -> QuantizedTensorStorage: raise NotImplementedError( f"{self.__class__.__name__} class does not implement make_empty function" ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15017913fe..a62e10bc57 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -22,6 +22,7 @@ Float8CurrentScaling, Float8BlockScaling, NVFP4BlockScaling, + CustomRecipe, ) from .constants import dist_group_type @@ -866,6 +867,8 @@ def create( cls = Float8BlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState + elif recipe.custom(): + cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( @@ -1191,3 +1194,56 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: ] raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + + +class CustomRecipeState(RecipeState): + 
"""State for CustomRecipe: produce quantizers per tensor.""" + + recipe: CustomRecipe + mode: str + num_quantizers: int + device: Optional[torch.device] + + def __init__( + self, + recipe: CustomRecipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + if device is None: + device = torch.device("cuda") + self.device = device + + if getattr(recipe, "qfactory", None) is None: + raise ValueError("CustomRecipe requires `qfactory`.") + + def make_quantizers(self) -> list: + qfactory = self.recipe.qfactory + out = [] + + # TODO(negvet): make_quantizers() should take roles from the operation + # Hardcode linear-specific roles for now + roles: List[str] + if self.mode == "forward": + roles = [ + ("linear_input", "linear_weight", "linear_output")[i % 3] + for i in range(self.num_quantizers) + ] + elif self.mode == "backward": + roles = [ + ("linear_grad_output", "linear_grad_input")[i % 2] + for i in range(self.num_quantizers) + ] + else: + roles = ["unknown"] * self.num_quantizers + + for i in range(self.num_quantizers): + # Get quantizer from the user defined factory + quantizer = qfactory(roles[i]) + out.append(quantizer) + return out diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index bf4fb97d2d..d60ff80593 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -38,15 +38,15 @@ _fsdp_gather_tensors, ) from ..constants import dist_group_type -from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer +from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import 
Float8BlockQuantizer -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -505,7 +505,7 @@ def fill_userbuffers_buffer_for_all_gather( local_tensor: torch.Tensor, quantizer: Optional[Quantizer], process_group, -) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: +) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]: """Fill local shard of Userbuffers buffer with data for all-gather Returns the full tensor and the local shard, both using the @@ -529,7 +529,7 @@ def fill_userbuffers_buffer_for_all_gather( # Unquantized data if quantizer is None: - if isinstance(local_tensor, QuantizedTensorBase): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor = local_tensor.dequantize() if comm.is_fp8_ubuf(): raise RuntimeError( @@ -542,8 +542,8 @@ def fill_userbuffers_buffer_for_all_gather( # FP8 data if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - if not isinstance(local_tensor, Float8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, Float8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() quantizer.set_usage(rowwise=True, columnwise=False) local_tensor = quantizer(local_tensor) @@ -554,7 +554,7 @@ def 
fill_userbuffers_buffer_for_all_gather( ) comm.copy_into_buffer(local_tensor._data, local_chunk=True) global_tensor_data = comm.get_buffer(shape=global_shape) - global_tensor = Float8TensorBase( + global_tensor = Float8TensorStorage( data=global_tensor_data, fp8_scale_inv=local_tensor._scale_inv, fp8_dtype=local_tensor._fp8_dtype, @@ -566,8 +566,8 @@ def fill_userbuffers_buffer_for_all_gather( if isinstance(quantizer, MXFP8Quantizer): # Cast to MXFP8 if needed - if not isinstance(local_tensor, MXFP8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, MXFP8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() local_tensor = quantizer(local_tensor) if not comm.is_fp8_ubuf(): @@ -622,7 +622,7 @@ def fill_userbuffers_buffer_for_all_gather( rowwise_data, rowwise_scale_inv = global_data, global_scale_inv else: columnwise_data, columnwise_scale_inv = global_data, global_scale_inv - global_tensor = MXFP8TensorBase( + global_tensor = MXFP8TensorStorage( rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, @@ -786,10 +786,10 @@ def _update_weight_quantizers(self) -> None: f"({len(weight_quantizers)}) must match" ) for weight, quantizer in zip(weight_tensors, weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement _get_weight_tensors function" @@ -1038,8 +1038,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor 
according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(self.fp8_meta["recipe"], "fp8_format"): + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) @@ -1170,9 +1171,9 @@ def grad_output_preprocess( grad_output, ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ): grad_output = quantizer(grad_output) @@ -1201,9 +1202,9 @@ def grad_output_preprocess( grad_output_.get_tensor(True), ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ) and ctx.use_bias @@ -1219,7 +1220,12 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance( grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ( + QuantizedTensor, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, + ), ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: @@ -1229,7 +1235,7 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, QuantizedTensorBase): + if not isinstance(grad_output, QuantizedTensorStorage): grad_output = quantizer(grad_output) return grad_output, grad_bias @@ -1383,14 +1389,14 @@ def get_weight_workspace( # Reset cache if workspace is invalid if out is not None and quantizer is not None: reset_cache = False - if isinstance(out, Float8TensorBase): + if isinstance(out, 
Float8TensorStorage): if ( not is_non_tn_fp8_gemm_supported() and quantizer.columnwise_usage and out._transpose is None ): reset_cache = True - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): if quantizer.rowwise_usage and out._rowwise_data is None: reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: @@ -1581,7 +1587,7 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensorBase): + if isinstance(tensor, QuantizedTensorStorage): quantizer = tensor._get_quantizer() if quantizer is None: continue diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5749d96c9f..b3adfb7dbf 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -44,7 +44,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -200,13 +200,13 @@ def forward( inputmats[0] = inp else: for inputmat in inputmats: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms if inp.requires_grad: for weight in weights_fp8: - if isinstance(weight, QuantizedTensorBase): + if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) tensors_to_save, tensor_objects = prepare_for_saving( @@ -338,7 +338,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and 
isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage( rowwise_usage=quantizer.rowwise_usage, columnwise_usage=quantizer.columnwise_usage, @@ -734,7 +734,7 @@ def forward( produced) """ assert not isinstance( - inp, QuantizedTensorBase + inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." @@ -868,16 +868,17 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] - if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): + if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " "Please make sure this is intentional." 
) weight_tensors = [ - w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors + w.dequantize() if isinstance(w, QuantizedTensorStorage) else w + for w in weight_tensors ] return weight_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6dbbd335eb..e1c0eab2dc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -58,7 +58,7 @@ from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers from ..tensor.quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -66,8 +66,8 @@ from ...debug.pytorch.debug_state import TEDebugState from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload @@ -200,7 +200,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not experimental + and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) # Apply normalization @@ -278,7 +278,7 @@ def forward( weightmat = weight quantized_weight = False if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorBase) + quantized_weight = not isinstance(weight, QuantizedTensorStorage) # Configure quantizer if weight_quantizer is not None: @@ -403,18 +403,18 @@ def forward( # Input with 
column-wise usage is needed for wgrad GEMM. if backward_needs_input: - if isinstance(ln_out, QuantizedTensorBase): + if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data # can be allgathered. if ( - isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage)) or not ctx.ln_out_needs_gather ): ln_out.update_usage(rowwise_usage=False) # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: @@ -685,9 +685,9 @@ def backward( # -------------------------------------------------- # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -806,14 +806,14 @@ def backward( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -999,7 +999,7 @@ def wgrad_gemm( 
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + # if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( @@ -1790,7 +1790,7 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a0e5f3aedd..2097f01b1d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -71,7 +71,7 @@ from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -116,9 +116,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "swiglu": (tex.swiglu, tex.dswiglu, None), } # no activation fusion written yet - # Per-tensor current scaling or fp8 blockwise scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: [] # TODO(ksivaman): Fuse nvfp4 act once kernel is available. 
- if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4(): + if ( + recipe.float8_current_scaling() + or recipe.float8_block_scaling() + or recipe.nvfp4() + or recipe.custom() + ): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -448,10 +453,18 @@ def forward( act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): - # tex.quantize does not support GELU fusion for blockwise. - act_out = activation_func(fc1_out, None) - act_out = tex.quantize(act_out, fc2_input_quantizer) + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + elif recipe.custom(): + # tex.quantize does not support custom quantizers + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) else: if fp8_calibration: act_out = activation_func(fc1_out, None) @@ -522,9 +535,9 @@ def forward( if is_grad_enabled: # Weight with column-wise usage is needed for dgrad GEMM. 
- if isinstance(fc1_weight_final, QuantizedTensorBase): + if isinstance(fc1_weight_final, QuantizedTensorStorage): fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorBase): + if isinstance(fc2_weight_final, QuantizedTensorStorage): fc2_weight_final.update_usage(columnwise_usage=True) if cpu_offloading: @@ -823,10 +836,10 @@ def backward( ) # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorBase + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) @@ -908,14 +921,14 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(act_out, QuantizedTensorBase): + if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1023,10 +1036,13 @@ def fc2_wgrad_gemm( ) # activation in high precision if ctx.fp8: - # TODO float8 blockwise current scaling has no bgrad fusion for now + # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now # TODO(ksivaman): Re-add fusion once kernel is available. 
- if isinstance( - ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + if ( + isinstance( + ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + ) + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1072,7 +1088,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorBase + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1143,7 +1159,7 @@ def fc2_wgrad_gemm( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1153,7 +1169,7 @@ def fc2_wgrad_gemm( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(dact, QuantizedTensorBase): + if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -2153,7 +2169,7 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cf7f58947b..02872439a3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -59,7 +59,7 
@@ from ..graph import is_graph_capturing from ..tensor.quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -178,7 +178,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase) and not experimental: + if not isinstance(inputmat, QuantizedTensorStorage) and not experimental: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -216,7 +216,7 @@ def forward( else: # Do not all-gather input tensor if fp8 or debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=True) else: if input_quantizer is None: @@ -372,7 +372,7 @@ def forward( if ( backward_needs_input and own_quantized_input - and isinstance(inputmat, QuantizedTensorBase) + and isinstance(inputmat, QuantizedTensorStorage) ): if ( ctx.backward_input_needs_gather @@ -391,7 +391,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM. 
if inp.requires_grad: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading and saved_inputmat is not None: @@ -404,7 +404,7 @@ def forward( ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, - weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, + weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -613,7 +613,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.requires_wgrad: if ctx.fp8 or ctx.debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass elif ctx.debug or ctx.experimental: @@ -632,7 +632,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) @@ -677,9 +677,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance( + weight_fp8, QuantizedTensorStorage + ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -763,7 +765,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], 
inputmat_total_work.wait() inputmat_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(inputmat_total, QuantizedTensorBase): + if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -805,7 +807,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -958,7 +960,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad, @@ -1524,7 +1526,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad): for name, q in zip(names, original_quantizers) ) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 99bbc34c45..13db35fc78 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -13,17 +13,17 @@ from .. 
import torch_version from ..fp8 import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor -from ..tensor.quantized_tensor import QuantizedTensorBase +from ..tensor.quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype -def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool: +def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" - return isinstance(tensor, QuantizedTensorBase) + return isinstance(tensor, QuantizedTensorStorage) def maybe_dequantize( - tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None + tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None ) -> torch.Tensor: """Dequantize tensor to given dtype or just convert if not a quantized tensor""" if is_quantized_tensor(tensor): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f8f95cf194..844e49ff07 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -29,7 +29,7 @@ ) from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -568,7 +568,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/ops/basic/dropout.py 
b/transformer_engine/pytorch/ops/basic/dropout.py index 30ccf5ebcd..38b2a59a73 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -56,7 +56,7 @@ def op_forward( out = input_ elif impl == "fused": x = input_ - if not isinstance(x, Float8TensorBase): + if not isinstance(x, Float8TensorStorage): x = maybe_dequantize(x, dtype=dtype) out, mask = tex.dropout_fwd(x, self.dropout_probability) elif impl == "unfused": diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index a604e57dcd..cbbe529d6a 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -23,7 +23,7 @@ ) from ...tensor.quantized_tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_dequantize, is_quantized_tensor from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( @@ -267,7 +267,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support 
all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 43846512d7..7689e20194 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,12 +6,42 @@ import torch -from .quantized_tensor import QuantizedTensor, Quantizer +from .quantized_tensor import ( + QuantizedTensorStorage, + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from .storage.float8_tensor_storage import Float8TensorStorage +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ - "QuantizedTensor", "Quantizer", + "Float8Quantizer", + "Float8CurrentScalingQuantizer", + "MXFP8Quantizer", + "Float8BlockQuantizer", + "NVFP4Quantizer", + "QuantizedTensorStorage", + "Float8TensorStorage", + "MXFP8TensorStorage", + "Float8BlockwiseQTensorStorage", + "NVFP4TensorStorage", + "QuantizedTensor", + "Float8Tensor", + "MXFP8Tensor", + "Float8BlockwiseQTensor", + "NVFP4Tensor", + "prepare_for_saving", + "restore_from_saved", ] @@ -48,24 +78,16 @@ def get_all_tensor_types(): """ Get all tensor-like types that can be used in TE. 
""" - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase - from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( - Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, - ) - from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase - all_tensor_types = [ torch.Tensor, torch.nn.Parameter, Float8Tensor, - Float8TensorBase, + Float8TensorStorage, MXFP8Tensor, - MXFP8TensorBase, + MXFP8TensorStorage, Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, + Float8BlockwiseQTensorStorage, NVFP4Tensor, - NVFP4TensorBase, + NVFP4TensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/__init__.py b/transformer_engine/pytorch/tensor/_internal/__init__.py deleted file mode 100644 index e13014bf75..0000000000 --- a/transformer_engine/pytorch/tensor/_internal/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 0e41fc9c51..16631a3d0d 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -13,8 +13,12 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine.common.recipe import Float8BlockScaling, Recipe -from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) from ..utils import devices_match, round_up_to_nearest_multiple aten = torch.ops.aten @@ -101,6 +105,10 @@ def update_quantized( dst._fp8_dtype = self.dtype return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: """Calculate the shape of the scaling tensor for blockwise quantization. @@ -270,7 +278,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling -class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): +class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. The tensor presents as having a standard, higher-precision dtype, @@ -295,7 +303,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): holds configuration about quantization and dequantization modes. 
""" - # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -334,15 +342,6 @@ def __repr__(self, *, tensor_contents=None): f" data_format={self._data_format}" ) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - assert self._quantizer is not None - return self._quantizer - def quantize_( self, tensor: torch.Tensor, @@ -361,8 +360,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 18750d0392..a4e68e53b0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -13,8 +13,12 @@ from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match -from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) from ..constants import dist_group_type aten = torch.ops.aten @@ -89,6 +93,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor 
implementation""" + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -147,7 +155,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=1 / self.scale, fp8_dtype=self.dtype, @@ -271,6 +279,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -333,7 +345,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, @@ -388,7 +400,7 @@ def supports_only_rowwise_all_gather(self) -> bool: return True -class Float8Tensor(Float8TensorBase, QuantizedTensor): +class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -443,19 +455,6 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromFloat8Func.apply(self, dtype) return _FromFloat8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. 
- - """ - if self._quantizer is not None: - return self._quantizer - # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling) - raise ValueError( - "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable" - ) - def quantize_( self, tensor: torch.Tensor, @@ -474,8 +473,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> Float8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d7f5f8c7d2..700de24c4e 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -16,8 +16,12 @@ from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple -from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) aten = torch.ops.aten @@ -67,6 +71,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: @@ -161,14 +169,14 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor) return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) - def onnx_dequantize(self, tensor: 
Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor: + def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor: return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling -class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): +class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -192,7 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """ - # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -236,17 +244,9 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromMXFP8Func.apply(self, dtype) return _FromMXFP8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. 
- - """ - if self._quantizer is not None: - return self._quantizer - return MXFP8Quantizer( - fp8_dtype=self._fp8_dtype, - ) + def _build_default_quantizer(self) -> Optional[Quantizer]: + """Build default quantizer for the tensor""" + return MXFP8Quantizer(fp8_dtype=self._fp8_dtype) def quantize_( self, @@ -266,8 +266,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> MXFP8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index b12e89956a..ca2154f554 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -21,7 +21,7 @@ round_up_to_nearest_multiple, ) -from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc aten = torch.ops.aten @@ -173,6 +173,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: @@ -332,7 +336,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling -class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): +class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data The tensor presents as having a standard, higher-precision dtype, @@ -365,7 +369,7 @@ class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): Nominal tensor datatype, used in dequantize. 
""" - # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 7b88d25196..a524d5c8de 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -5,7 +5,7 @@ """Tensor with quantized data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Any, Dict, Union +from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union import abc import copy import warnings @@ -13,12 +13,11 @@ import torch from torch.utils._pytree import tree_map -import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -class QuantizedTensorBase: - r"""Base class for all *TensorBase classes. +class QuantizedTensorStorage: + r"""Base class for all *TensorStorage classes. This class (and its subclasses) are optimization for when the full QuantizedTensor is not needed (when it is fully @@ -26,9 +25,9 @@ class QuantizedTensorBase: PyTorch's autograd). When creating a new tensor type X one should create both - XTensorBase class inheriting from QuantizedTensorBase and - XTensor inheriting from XTensorBase and QuantizedTensor. - XTensorBase should contain all data members needed to + XTensorStorage class inheriting from QuantizedTensorStorage and + XTensor inheriting from XTensorStorage and QuantizedTensor. 
+ XTensorStorage should contain all data members needed to implement the functionality of the tensor, while XTensor should only implement the functionality needed to behave like regular torch.Tensor (liek __torch_dispatch__).""" @@ -59,7 +58,7 @@ def update_usage( f"{self.__class__.__name__} class does not implement update_usage function" ) - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement prepare_for_saving function" @@ -73,6 +72,30 @@ def restore_from_saved( f"{self.__class__.__name__} class does not implement restore_from_saved function" ) + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return self._build_default_quantizer() + + def _build_default_quantizer(self) -> Quantizer: + """Build default quantizer for the tensor""" + raise ValueError( + f"{self.__class__.__name__} has no quantizer " + "and no default quantizer is available defined in the subclass." 
+ ) + + def quantize_( + self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None + ) -> QuantizedTensor: + """Quantize tensor in-place""" + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + def update_quantizer(self, quantizer: Quantizer): """Update quantizer for the tensor""" if self._quantizer is None: @@ -83,13 +106,13 @@ def update_quantizer(self, quantizer: Quantizer): def prepare_for_saving( - *tensors: Union[torch.Tensor, QuantizedTensorBase], + *tensors: Union[torch.Tensor, QuantizedTensorStorage], ) -> Tuple[ - list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]] + list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]] ]: """Prepare tensors for saving. Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save - the internal TensorBase types too.""" + the internal *TensorStorage types too.""" tensor_list, tensor_objects_list = [], [] for tensor in tensors: @@ -104,12 +127,12 @@ def prepare_for_saving( def restore_from_saved( - tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]], + tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], return_saved_tensors: bool = False, ) -> ( - list[Optional[torch.Tensor | QuantizedTensorBase]] - | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]] + list[Optional[torch.Tensor | QuantizedTensorStorage]] + | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]] ): """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] @@ -178,7 +201,6 @@ def __repr__(self): ")" ) - @abc.abstractmethod def update_quantized( self, src: torch.Tensor, @@ -187,6 +209,9 @@ def update_quantized( noop_flag: Optional[torch.Tensor] = None, ) -> 
QuantizedTensor: """Quantize tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized" + ) def quantize( self, @@ -199,8 +224,14 @@ def quantize( if out is not None: return self.update_quantized(tensor, out) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self) - return _QuantizeFunc.forward(None, tensor, self) + return _QuantizeFunc.apply(tensor, self.quantize_impl) + return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_impl function" + ) def multi_quantize(self, list_of_tensors): """Quantize multiple tensors""" @@ -213,7 +244,6 @@ def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor""" return self.quantize(tensor) - @abc.abstractmethod def make_empty( self, shape: Iterable[int], @@ -222,8 +252,11 @@ def make_empty( device: Optional[torch.device] = None, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function, " + "required for construction of unintialized quantized tensor" + ) - @abc.abstractmethod def calibrate(self, tensor: torch.Tensor) -> None: """Calibrate quantizer state @@ -252,13 +285,21 @@ def copy(self) -> Quantizer: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement onnx_quantize" + ) def onnx_dequantize(self, tensor) -> torch.Tensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement onnx_dequantize" + ) - @abc.abstractmethod def _get_compatible_recipe(self) -> Union[type[Recipe], 
None]: """Returns recipe class that is compatible with this quantizer""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe" + ) def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" @@ -270,20 +311,21 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a class _QuantizeFunc(torch.autograd.Function): - """Cast to FP8 from other dtype""" + """Quantize tensor""" @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused tensor: torch.Tensor, - quantizer: Quantizer, + quantize_impl: Callable, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tex.quantize(tensor, quantizer) + return quantize_impl(tensor) @staticmethod def backward( - _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py new file mode 100644 index 0000000000..9cb228f3a7 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Storage for quantized tensors.""" + +from .float8_tensor_storage import Float8TensorStorage # noqa: F401 +from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 +from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py similarity index 98% rename from transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index da0220eb7a..9040ea3a43 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -13,7 +13,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import Float8BlockScaleTensorFormat -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType_To_Torch @@ -22,7 +22,7 @@ from ...utils import _empty_tensor -class Float8BlockwiseQTensorBase(QuantizedTensorBase): +class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8BlockwiseQTensor. 
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this @@ -53,7 +53,7 @@ def __new__( *args, **kwargs, ): - if cls is Float8BlockwiseQTensorBase: + if cls is Float8BlockwiseQTensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -98,7 +98,7 @@ def _is_gemm_ready_format(self) -> bool: def prepare_for_saving( self, - ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]: """ Prepare the tensor base for saving for backward """ @@ -366,7 +366,7 @@ def __repr__(self): data = self.dequantize() descriptor = "columnwise" return ( - "Float8BlockwiseQTensorBase(" + "Float8BlockwiseQTensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"{descriptor}_scaled_data={data}" ) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py similarity index 96% rename from transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 6d48223443..b9533edb6e 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,7 +12,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType as torch_to_transformer_engine_dtype @@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: Float8TensorBase, + tensor: Float8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -52,7 +52,7 @@ def backward( return grad, None -class 
Float8TensorBase(QuantizedTensorBase): +class Float8TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8Tensor. Float8Tensor inherits from the PyTorch tensor class and this mixin @@ -81,7 +81,7 @@ def __new__( quantizer: Optional[Quantizer] = None, **kwargs, ): - if cls is Float8TensorBase: + if cls is Float8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -116,7 +116,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [self._data, self._transpose, self._scale_inv] self._data = None @@ -163,7 +163,7 @@ def view(self, shape: torch.Size): if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]: out_transpose = None - return Float8TensorBase( + return Float8TensorStorage( data=out_data, fp8_scale_inv=self._scale_inv, fp8_dtype=self._fp8_dtype, @@ -173,7 +173,7 @@ def view(self, shape: torch.Size): def __repr__(self): return ( - "Float8TensorBase(" + "Float8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " f"data={self.dequantize()}" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py similarity index 97% rename from transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 5a7dd6b449..c1f30146c9 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -13,7 +13,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from 
..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType as torch_to_transformer_engine_dtype @@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: MXFP8TensorBase, + tensor: MXFP8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -49,7 +49,7 @@ def backward( return grad, None -class MXFP8TensorBase(QuantizedTensorBase): +class MXFP8TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of MXFP8Tensor. MXFP8Tensor inherits from the PyTorch tensor class and this mixin @@ -77,7 +77,7 @@ def __new__( *args, **kwargs, ): - if cls is MXFP8TensorBase: + if cls is MXFP8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -112,7 +112,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [ self._rowwise_data, @@ -192,7 +192,7 @@ def view(self, shape: torch.Size): if cur_columnwise_data is not None: new_columnwise_data = cur_columnwise_data.view(*shape) - return MXFP8TensorBase( + return MXFP8TensorStorage( rowwise_data=new_rowwise_data, rowwise_scale_inv=self._rowwise_scale_inv, columnwise_data=new_columnwise_data, @@ -205,7 +205,7 @@ def __repr__(self): data_rowwise = self.dequantize() return ( - "MXFP8TensorBase(" + "MXFP8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"rowwise_scaled_data={data_rowwise}" f"rowwise_scale_inv={self._rowwise_scale_inv}, " diff --git a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py 
similarity index 98% rename from transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index df187d6741..350103f7ca 100644 --- a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -16,7 +16,7 @@ # import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage # from ...constants import TE_DType as torch_to_transformer_engine_dtype from ..quantized_tensor import Quantizer @@ -39,7 +39,7 @@ class _FromNVFP4Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: NVFP4TensorBase, + tensor: NVFP4TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -89,7 +89,7 @@ def backward( return grad, None -class NVFP4TensorBase(QuantizedTensorBase): +class NVFP4TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of NVFP4Tensor. 
NVFP4Tensor inherits from the PyTorch tensor class and this mixin @@ -161,7 +161,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [ self._rowwise_data, @@ -267,7 +267,7 @@ def view(self, shape: torch.Size): new_columnwise_data = self._columnwise_data.view(byte_shape) # Construct tensor - return NVFP4TensorBase( + return NVFP4TensorStorage( rowwise_data=new_rowwise_data, rowwise_scale_inv=self._rowwise_scale_inv, columnwise_data=new_columnwise_data, @@ -282,7 +282,7 @@ def __repr__(self): data_rowwise = self.dequantize() return ( - "NVFP4TensorBase(" + "NVFP4TensorStorage(" f"rowwise_scaled_data={data_rowwise}," f"rowwise_scale_inv={self._rowwise_scale_inv}," f"amax_rowwise={self._amax_rowwise}," diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index a4bdf5e07d..cc02494013 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -454,7 +454,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) -def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: +def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) 
-> bool: """Check if an environment or object is using experimental Kitchen middleware. Returns False if x is a torch.Tensor. @@ -466,6 +466,6 @@ def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) - # Detect if the object is experimental if isinstance(x, torch.Tensor): return False - if not isinstance(x, (Quantizer, QuantizedTensorBase)): - raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") + if not isinstance(x, (Quantizer, QuantizedTensorStorage)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance") return hasattr(x, "experimental") and x.experimental diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 1a0722f894..8ea3623713 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -225,13 +225,15 @@ def forward( ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections from transformer_engine.pytorch.float8_tensor import Float8Tensor - from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase + from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( + Float8TensorStorage, + ) - if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance( mixed_x_layer, Float8Tensor ): return tuple( - Float8TensorBase( + Float8TensorStorage( fp8_scale_inv=mixed_x_layer._scale_inv, fp8_dtype=mixed_x_layer._fp8_dtype, data=x.squeeze(split_dim) if squeeze else x, From ac4e0fd63afb1998904695e1321c5631192c3a85 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 1 Oct 2025 10:02:26 -0400 Subject: [PATCH 48/78] [JAX] Rework amax reduction over TPSP (#2218) * rm using_global_amax_of_x Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/quantization.py | 44 +++++++++++++------ transformer_engine/jax/dense.py | 16 +++---- 
2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 021af4c9db..9f9e8fec06 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -551,7 +551,10 @@ class AmaxCalculationPrimitive(BasePrimitive): name = "jax_local_amax" multiple_results = False - impl_static_args = (1,) # amax_scope + impl_static_args = ( + 1, + 2, + ) # amax_scope, batch_sequence_transpose inner_primitive = None outer_primitive = None @@ -560,11 +563,12 @@ def abstract( x_aval, *, amax_scope, + batch_sequence_transpose, ): """ amax calcuation abstract """ - del amax_scope + del amax_scope, batch_sequence_transpose dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -576,17 +580,19 @@ def abstract( def impl( x, amax_scope, + batch_sequence_transpose, ): """ amax calcuation implementation """ - del amax_scope + del amax_scope, batch_sequence_transpose amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) return amax @staticmethod def infer_sharding_from_operands( amax_scope, + batch_sequence_transpose, mesh, arg_infos, result_infos, @@ -594,7 +600,7 @@ def infer_sharding_from_operands( """ amax calcuation infer_sharding_from_operands """ - del (amax_scope, arg_infos, result_infos) # Unused. + del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. 
amax_sharding = NamedSharding( mesh, PartitionSpec(None), @@ -605,6 +611,7 @@ def infer_sharding_from_operands( @staticmethod def partition( amax_scope, + batch_sequence_transpose, mesh, arg_infos, result_infos, @@ -613,25 +620,26 @@ def partition( amax calcuation partition """ del result_infos - + x_spec = get_padded_spec(arg_infos[0]) amax_sharding = NamedSharding( mesh, PartitionSpec(None), - desc="AmaxCalculationPrimitive.out_sharding", + desc="AmaxCalculation.amax_sharding", ) def sharded_impl(x): amax = AmaxCalculationPrimitive.impl( x, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) - if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP - gmesh = global_mesh_resource() - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh) + gmesh = global_mesh_resource() + sequence_dim = 0 if batch_sequence_transpose else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource: amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - - if amax_scope is AmaxScope.FSDP: # Run AR across FSDP - gmesh = global_mesh_resource() + # Run AR across FSDP + if amax_scope is AmaxScope.FSDP: amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) return amax @@ -640,11 +648,11 @@ def sharded_impl(x): return mesh, sharded_impl, amax_sharding, arg_shardings @staticmethod - def shardy_sharding_rule(amax_scope, mesh, value_types, result_types): + def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): """ amax calcuation shardy_sharding_rule """ - del amax_scope, mesh, result_types + del amax_scope, batch_sequence_transpose, mesh, result_types prefix = "AmaxCal" input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) output_spec = (f"{prefix}_amax",) @@ -701,6 +709,7 @@ def _quantize_dbias_impl( dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, 
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -745,6 +754,8 @@ def _quantize_dbias_impl( quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis, + amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias @@ -760,6 +771,7 @@ def _quantize_dbias_impl( amax = AmaxCalculationPrimitive.outer_primitive.bind( x.data, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: @@ -833,6 +845,7 @@ def quantize( quantizer: Quantizer, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -853,6 +866,7 @@ def quantize( quantizer=quantizer, flatten_axis=flatten_axis, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) return out @@ -863,6 +877,7 @@ def quantize_dbias( is_dbias: bool = True, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -889,6 +904,7 @@ def quantize_dbias( is_dbias=is_dbias, flatten_axis=flatten_axis, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 23df1a0ce2..3cdf6ba7a1 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -67,7 +67,6 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] 
= None, - using_global_amax_of_x: bool = False, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, ): @@ -86,7 +85,6 @@ def dense( input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix output_axes: Logical axes for sharding the output - using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types @@ -109,14 +107,13 @@ def dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) def _dense( x, kernel, @@ -126,7 +123,6 @@ def _dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management ): @@ -144,7 +140,6 @@ def _dense( input_axes: Logical axes for sharding the activation input output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix - using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. collective_op_set: A set of CollectiveOp objects for forward and backward passes. 
quantizer_set: QuantizerSet which contains quantizers for different tensor types @@ -160,7 +155,6 @@ def _dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ) @@ -176,7 +170,6 @@ def _dense_fwd_rule( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ): @@ -203,7 +196,8 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, - amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL, + amax_scope=AmaxScope.TPSP, + batch_sequence_transpose=batch_sequence_transpose, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -250,7 +244,6 @@ def _dense_bwd_rule( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, ctx, grad, @@ -280,7 +273,8 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP, + amax_scope=AmaxScope.TPSP, + batch_sequence_transpose=batch_sequence_transpose, ) # GEMM NT From b0d562d8ac3f0ce36131471ae03d87b90a797e6f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 1 Oct 2025 10:13:40 -0400 Subject: [PATCH 49/78] [JAX] Fix `rng_state` shape in fused attention (#2217) fix rng_state shape Signed-off-by: Phuong Nguyen Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 625f42049f..db2537c38f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1820,7 +1820,7 @@ def ring_attn_fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. 
- rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): @@ -2306,7 +2306,7 @@ def fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. - rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): From ac886c3594a80e05ad6682b13e3099e3bdc8248d Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 1 Oct 2025 19:32:16 +0200 Subject: [PATCH 50/78] [PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226) Fix QuantizedTensorBase -> QuantizedTensorStorage Signed-off-by: Evgeny --- .../attention/dot_product_attention/backends.py | 6 +++--- .../dot_product_attention/context_parallel.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index f72c1eb9e0..3a13758382 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -25,7 +25,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) @@ -1312,7 +1312,7 @@ def backward(ctx, d_out): # d_out is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, 
QuantizedTensorStorage): d_out = ctx.dO_quantizer(d_out) if not ctx.use_FAv2_bwd: d_out._data = d_out._data.contiguous() @@ -1479,7 +1479,7 @@ def backward(ctx, d_out): ctx.dP_quantizer, ) else: - if isinstance(d_out, QuantizedTensorBase): + if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 539caffbb9..d0ddae25ef 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -21,7 +21,7 @@ ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -1823,7 +1823,7 @@ def backward(ctx, dout): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -1997,7 +1997,7 @@ def backward(ctx, dout): dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if isinstance(dout, QuantizedTensorBase): + if isinstance(dout, QuantizedTensorStorage): dout = 
dout.dequantize(dtype=bwd_nominal_dtype) dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ @@ -3396,7 +3396,7 @@ def backward(ctx, dout): if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorBase): + if not isinstance(dout, QuantizedTensorStorage): dout = ctx.dO_quantizer(dout) dout_fp8 = dout dqkv_te_dtype = dout._fp8_dtype @@ -3409,7 +3409,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if isinstance(dout, QuantizedTensorBase): + if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} From f0a9404881777ba0496e56e62d682ebb3896e91c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 1 Oct 2025 10:33:05 -0700 Subject: [PATCH 51/78] Fix hang during debug build (#2221) Disable debug build for cutlass GEMM Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4915080e8..e0fe3c04a6 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -164,6 +164,9 @@ else() message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") endif() +# Disable debug build for cutlass due to hang. 
+set_source_files_properties("gemm/cutlass_grouped_gemm.cu" PROPERTIES COMPILE_FLAGS "-g0") + # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas From 90449f796718022fd34ae518c7f4a37df0fc76f2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 1 Oct 2025 14:09:38 -0700 Subject: [PATCH 52/78] Convert `NVFP4BlockScaling` to dataclass (#2227) Fix passing args to nvfp4 recipe Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 324b5d50c8..1a9b029878 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -382,6 +382,7 @@ def __repr__(self) -> str: ) +@dataclass() class NVFP4BlockScaling(Recipe): """ Use the NVFP4 scaling strategy. From aee5a82108bc3053ef01fa1bd2459fe7c0a154f5 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 1 Oct 2025 14:12:15 -0700 Subject: [PATCH 53/78] Fix the cuBLAS workspace alignment (#2223) * Fix the cublas workspace alignment Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak Signed-off-by: Przemyslaw Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- transformer_engine/common/gemm/cublaslt_gemm.cu | 16 ++++++++++++---- transformer_engine/pytorch/module/base.py | 4 ++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu 
b/transformer_engine/common/gemm/cublaslt_gemm.cu index ab80fe7698..a4810881c4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -679,6 +679,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif } + // align the workspace to 256 B + const int required_alignment = 256; + const auto original_workspace_alignment = _getAlignment(reinterpret_cast(workspace)); + uint8_t *aligned_workspace_ptr = + reinterpret_cast(workspace) + required_alignment - original_workspace_alignment; + workspaceSize = workspaceSize - required_alignment + original_workspace_alignment; + const auto new_workspace_alignment = + _getAlignment(reinterpret_cast(aligned_workspace_ptr)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); @@ -686,7 +694,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); - const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -695,8 +702,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - NVTE_CHECK(workspace_alignment % 256 == 0, - "cuBLAS workspace pointer must be aligned to 256 bytes, got ", 
workspace_alignment); + NVTE_CHECK(new_workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", + new_workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, @@ -714,7 +722,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, C, /* C */ Cdesc, D, /* D */ Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ + aligned_workspace_ptr, /* workspace */ workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d60ff80593..3ae3895689 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -78,8 +78,8 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales - return 32 * 1024 * 1024 + 256 + # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales + return 32 * 1024 * 1024 + 1024 return 4_194_304 From c1003181dbd5123a3e349266e8dc118f89d78485 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:42:36 -0700 Subject: [PATCH 54/78] [PyTorch] Set usages for linear op quantizers before forward (#2222) * Make sure to set usages for linear op quantizers before forward Signed-off-by: Tim Moon * Avoid unsupported case for fused dbias+quantize kernel Hopper does not support dbias + FP8 cast without FP8 transpose. 
Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- tests/pytorch/distributed/test_fusible_ops.py | 215 +++++++++++++++++- .../pytorch/csrc/extensions/bias.cpp | 23 +- .../pytorch/ops/basic/basic_linear.py | 76 ++++--- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 4 + transformer_engine/pytorch/ops/op.py | 11 + 8 files changed, 296 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 11fe4333bc..af0f0e9313 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -635,6 +635,204 @@ def _test_linear( torch.testing.assert_close(db_test, db_ref, **tols) +def _test_mlp( + *, + bias: bool = True, + hidden_size: int = 32, + local_batch_size: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_weight: bool = False, + sequence_parallel: bool = False, +) -> None: + """2-layer MLP + + MLP includes GELU activation in order to test op fusions. Model + performs warmup steps in order to test inter-step logic. 
+ + """ + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + mlp_size = hidden_size * world_size + batch_size = local_batch_size + if sequence_parallel: + batch_size *= world_size + in_shape = (batch_size, hidden_size) + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + (mlp_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b1_ref, b1_test = None, None + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, mlp_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = None, None + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + (mlp_size,), + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = make_reference_and_test_tensors( + (world_size, hidden_size), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w1_ref) + if bias: + y_ref += b1_ref + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w2_ref) + if bias: + y_ref += b2_ref.sum(dim=0) + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + local_mlp_size = mlp_size // world_size + 
local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size) + dx_ref = x_ref.grad + dw1_ref = w1_ref.grad[local_mlp_slice, :] + w1_ref = w1_ref[local_mlp_slice, :] + w1_test = w1_test[local_mlp_slice, :] + dw2_ref = w2_ref.grad[:, local_mlp_slice] + w2_ref = w2_ref[:, local_mlp_slice] + w2_test = w2_test[:, local_mlp_slice] + if bias: + db1_ref = b1_ref.grad[local_mlp_slice] + b1_ref = b1_ref[local_mlp_slice] + b1_test = b1_test[local_mlp_slice] + db2_ref = b2_ref.grad[rank, :] + b2_ref = b2_ref[rank, :] + b2_test = b2_test[rank, :] + else: + db1_ref = None + db2_ref = None + if sequence_parallel: + local_batch_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + x_ref = x_ref[local_batch_slice, ...] + dx_ref = dx_ref[local_batch_slice, ...] + x_test = x_test[local_batch_slice, ...].clone() + y_ref = y_ref[local_batch_slice, ...] + dy_ref = dy_ref[local_batch_slice, ...] + dy_test = dy_test[local_batch_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.GELU(), + te_ops.Linear( + hidden_size, + mlp_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="column", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + te_ops.Linear( + mlp_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="row", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + ) + with torch.no_grad(): + model[1].weight.copy_(w1_test) + model[3].weight.copy_(w2_test) + if bias: + model[1].bias.copy_(b1_test) + model[3].bias.copy_(b2_test) + del w1_test, w2_test, b1_test, b2_test + + # Warmup steps + for _ in range(3): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + 
y_test.backward(dy_test) + x_test.grad = None + model[1].weight.grad = None + model[3].weight.grad = None + if bias: + model[1].bias.grad = None + model[3].bias.grad = None + + # Forward and backward step + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw1_test, dw1_ref, **tols) + torch.testing.assert_close(dw2_test, dw2_ref, **tols) + if bias: + db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") + db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db1_test, db1_ref, **tols) + torch.testing.assert_close(db2_test, db2_ref, **tols) + + def _test_fp8_scale_update( *, amax_history_len: int = 31, @@ -801,16 +999,31 @@ def run_parallel_tests() -> None: for config in itertools.product( quantization_list, ("column", "row"), + (False, True), ): if rank == 0: print(f"Running _test_linear with {config=}") - quantization, tensor_parallel_mode = config + quantization, tensor_parallel_mode, sequence_parallel = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, + sequence_parallel=sequence_parallel, + ) + + # MLP + for config in itertools.product(quantization_list, 
(False, True)): + if rank == 0: + print(f"Running _test_mlp with {config=}") + quantization, sequence_parallel = config + dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + _test_mlp( + bias=True, # bias=False is tested in _test_basic_linear + dtype=dtype, + quantization=quantization, + sequence_parallel=sequence_parallel, ) # FP8 scale update diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0531596dd3..b0435d2723 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -54,10 +54,25 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - // Unfused impl if quantizer is not supported - const bool with_fused_dbias_quantize_kernel = - detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); - if (!with_fused_dbias_quantize_kernel) { + // Check if fused kernel is supported + bool with_fused_kernel = false; + if (detail::IsFloat8Quantizers(quantizer.ptr())) { + auto prop = at::cuda::getCurrentDeviceProperties(); + const size_t sm_arch = 10 * prop->major + prop->minor; + if (sm_arch >= 100) { + // Fused kernel for dbias + FP8 cast on SM arch 10.0+ + with_fused_kernel = true; + } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) { + // Fused kernel for dbias + FP8 cast + FP8 transpose + with_fused_kernel = true; + } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Fused kernel for dbias + MXFP8 quantize + with_fused_kernel = true; + } + + // Apply unfused impl if fused kernel is not supported + if (!with_fused_kernel) { at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; diff --git 
a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 844e49ff07..cb2119296f 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -322,6 +322,20 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + weight_requires_grad = requires_grad and self.weight.requires_grad + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -352,6 +366,35 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: and not getattr(self, "_with_quantized_weight", False) ) + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. 
+ if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin @@ -731,7 +774,7 @@ def _functional_backward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(columnwise=True) + input_quantizer.set_usage(rowwise=False, columnwise=True) if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -912,42 +955,13 @@ def op_forward( input_requires_grad = ctx.requires_grad weight_requires_grad = ctx.requires_grad and 
self.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - # Configure quantizers - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - - # Recipe-specific configuration - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - if recipe.nvfp4(): - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - 
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 02bcfee0ae..ab271e17b7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -85,7 +85,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 15cc081c1d..4831ae4076 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -79,7 +79,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 21190d4fcf..72e17f64e8 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -58,7 +58,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad 
- # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index ccd7ee52b2..6f80a7a1f3 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -472,6 +472,10 @@ def __call__( # Attempt to fuse operations if neccesary self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + # Initialization before forward + for idx, op in enumerate(self._basic_ops): + op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) + # Fuser forward pass if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 903bc49d51..103ebf2418 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -65,6 +65,13 @@ def is_fused_op(self) -> bool: def pre_first_fuser_forward(self) -> None: """Preprocessing before first fuser forward pass""" + def pre_fuser_forward( + self, + *, + requires_grad: bool, # pylint: disable=unused-argument + ) -> None: + """Preprocessing before fuser forward pass""" + def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" @@ -710,6 +717,10 @@ def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: op.pre_first_fuser_forward() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + for op in self.basic_ops: + op.pre_fuser_forward(requires_grad=requires_grad) + def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin From f936c2ac82f348deba74180eea1732a55e118cc6 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:31:55 -0700 Subject: [PATCH 55/78] [JAX] Fix code block in 
fp8_autocast docstring (#2228) Fix code block in fp8_autocast docstring Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- transformer_engine/jax/quantize/helper.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3d460e81ab..67f0a68c6a 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -404,20 +404,20 @@ def fp8_autocast( This context manager enables FP8 quantization for the duration of its context. .. code-block:: python - mesh_shape = (4, 2) - dp_mesh_axis_name = 'data_parallel' - tp_mesh_axis_name = 'tensor_parallel' - devices = np.asarray(jax.devices()).reshape(*mesh_shape) + mesh_shape = (4, 2) + dp_mesh_axis_name = 'data_parallel' + tp_mesh_axis_name = 'tensor_parallel' + devices = np.asarray(jax.devices()).reshape(*mesh_shape) - with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): - mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) + with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): + mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) - with fp8_autocast(enabled=True, mesh_resource=mesh_resource): - rules = extend_logical_axis_rules(tuple()) - transformer = TransformerLayer() + with fp8_autocast(enabled=True, mesh_resource=mesh_resource): + rules = extend_logical_axis_rules(tuple()) + transformer = TransformerLayer() - with partitioning.axis_rules(rules): - pjit(transformer.init, ...)(...) + with partitioning.axis_rules(rules): + pjit(transformer.init, ...)(...) .. 
note:: We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`, From be7f43f10ce34ce3f878a63933a6dd45eb10bafc Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:32:24 -0700 Subject: [PATCH 56/78] [JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229) Fix shard map issue Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- transformer_engine/jax/sharding.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 7a82612695..d3a7952d39 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -131,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): # We want to exclude the axes that already used by shard_map and shard_map # only sets those in the abstract_mesh, not the physical one manual_axis_names = get_abstract_mesh().manual_axes - cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec) + + # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too + def filter_manual_axes(name_or_tuple): + if isinstance(name_or_tuple, tuple): + out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + if len(out) == 0: + return None + return out + if name_or_tuple in manual_axis_names: + return None + return name_or_tuple + + cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + + if cleaned_axis_names == (None,) * len(cleaned_axis_names): + return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) return jax.lax.with_sharding_constraint(x, cleaned_pspec) From e30c36a30883c49b820096a0bb856c7ea71bebd5 Mon Sep 17 00:00:00 2001 From: hx Date: Fri, 3 Oct 2025 05:04:28 +0800 Subject: [PATCH 57/78] [PyTorch] fix int32 overflow in permute kernels (#2196) * fix 
overflow of int32 in permute kernels Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/triton/permutation.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index ceb88108f1..6292acb69b 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -324,7 +324,8 @@ def _permute_kernel( pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = cur_off < hidden_size - input_off = pid_t * stride_input_token + cur_off * stride_input_hidden + src_row = pid_t.to(tl.int64) + input_off = src_row * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) if PERMUTE_SCALE: mask_scale = cur_off < scale_hidden_dim @@ -338,7 +339,7 @@ def _permute_kernel( for idx in tl.range(n_routed): dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: permuted_scale_off = ( @@ -519,7 +520,7 @@ def _unpermute_kernel( for idx in tl.range(n_routed): src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) inp = inp.to(compute_type) @@ -550,7 +551,8 @@ def _unpermute_kernel( prob = tl.load(permuted_probs_ptr + permuted_prob_off) tl.store(unpermuted_probs_ptr + 
unpermuted_prob_off, prob) accumulator = accumulator.to(data_type) - output_off = pid_t * stride_output_token + current_offset * stride_output_hidden + dst_row = pid_t.to(tl.int64) + output_off = dst_row * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) @@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel( for idx in tl.range(n_routed): dst_row = tl.load( row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) expert_idx = tl.load( row_id_map_ptr + pid * stride_row_id_map_token @@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel( while current_start < hidden_size: current_offset = current_start + tl.arange(0, BLOCK_SIZE) mask = current_offset < hidden_size + src_row = pid.to(tl.int64) input_off = ( - pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden + src_row * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) inp = inp.to(compute_type) @@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel( pid_t = tl.program_id(0) pid_h = tl.program_id(1) if FORWARD: - src_row = pid_t - dst_row = tl.load(row_id_map_ptr + pid_t) + src_row = pid_t.to(tl.int64) + dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) else: - src_row = tl.load(row_id_map_ptr + pid_t) - dst_row = pid_t + src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) + dst_row = pid_t.to(tl.int64) current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = current_offset < hidden_size input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden From b840898b75162bce68fbc3c9c8234b6f23dcdbff Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Fri, 3 Oct 2025 09:39:46 -0700 Subject: [PATCH 58/78] [JAX] Clamped Swiglu Integration (#2194) Signed-off-by: Varun Thumbe *Jax integration for clamped swiglu. 
This is the continuation of PR which added Clamped Swiglu(used in GPT OSS) support in TE along with Pytorch integration. This PR hooks up the clamped swiglu and dswiglu's nvte APIs to TE Jax. --- tests/jax/test_custom_call_compute.py | 89 ++++++--- .../include/transformer_engine/activation.h | 1 + .../common/util/cast_gated_kernels.cuh | 14 +- transformer_engine/jax/activation.py | 26 ++- .../jax/cpp_extensions/activation.py | 176 +++++++++++++++--- transformer_engine/jax/csrc/extensions.h | 17 ++ .../jax/csrc/extensions/activation.cpp | 52 ++++-- .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/flax/module.py | 13 +- transformer_engine/jax/flax/transformer.py | 5 + transformer_engine/jax/layernorm_mlp.py | 18 +- 11 files changed, 324 insertions(+), 88 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7f15eec892..7a4fa268af 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -170,6 +170,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -182,17 +183,21 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, 
quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +214,20 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 +247,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,9 +256,21 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = 
self.value_n_grad_ref_func(x, activation_type) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -273,10 +299,18 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +330,18 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = 
self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) @@ -734,6 +776,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) @@ -785,7 +828,7 @@ def _test_quantize_dact_dbias( (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. or ( - activation_type == ("squared_relu",) + activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")} and in_dtype == jnp.bfloat16 and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING ) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040d..4e48088586 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU }; /*! \brief Computes the GeLU activation of the input. 
diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index ca37a28319..93086bd827 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; @@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(input, 
"gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); NVTE_CHECK(input.flat_last_dim() % 2 == 0, @@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); @@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1318,7 +1316,7 @@ namespace detail { template void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP &p, cudaStream_t stream) { + ParamOP p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43c..daa3679c48 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp - from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor @@ -22,6 +21,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[tex.activation.ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. 
@@ -32,17 +32,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index a8c14a6087..925c1d01ae 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,6 +5,7 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial +from dataclasses import dataclass import jax import jax.numpy as jnp @@ -12,9 +13,9 @@ from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -51,17 +52,87 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + 
("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + limit: float = 7.0 + alpha: float = 1.702 + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work. + + Returns: + int: Hash value of the dataclass instance. + """ + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + """Parameters for various activation functions. + Currently only Clamped SwiGLU activation has parameters. + """ + + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work""" + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. 
+ """ + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + # This function is used for ClampedSwiGLU + # used in GPT OSS where the gates are not only clamped + # but also shifted by +1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -100,11 +172,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -150,6 +223,7 @@ def lowering( is_2x, scale_dtype, is_outer, + act_params, ): """ te_gated_act_lu_p lowering rules @@ -158,9 +232,14 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, 
scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -175,6 +254,7 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation @@ -193,6 +273,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -221,6 +302,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -242,6 +324,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -255,6 +338,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -266,6 +350,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -318,6 +403,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -378,6 +464,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -405,11 +492,12 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params prefix = "ActLu_" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] @@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -474,11 +562,12 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract """ - del act_enum + del act_enum, act_params dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -575,6 +664,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -593,6 +683,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -608,6 +699,7 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl @@ -627,6 +719,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) 
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -655,6 +748,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -685,6 +779,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -699,11 +794,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -774,6 +870,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -854,6 +951,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -880,11 +978,13 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types + + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 @@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. 
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -985,6 +1090,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1008,24 +1114,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1037,6 +1141,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1051,6 +1156,7 @@ def act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, _ = _quantize_dbias_impl( out, @@ -1060,7 +1166,6 @@ def act_lu( amax_scope=amax_scope, ) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1080,6 +1185,7 @@ def act_lu( 
is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1102,6 +1208,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1118,7 +1225,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1131,8 +1238,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1148,6 +1254,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1163,7 +1270,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1180,6 +1291,7 @@ 
def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, ) if war_output is not None: return war_output @@ -1191,6 +1303,7 @@ def quantize_dact_dbias( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1203,7 +1316,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1229,6 +1345,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1257,6 +1374,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1270,11 +1388,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. 
""" + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ab95002fa..bbfc62120a 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -36,6 +36,15 @@ namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember("clamped_swiglu")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index b2b3db52c8..0ecf791505 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, 
JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("act_params"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int); + act_enum, scaling_mode, is_2x_int, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("act_params")); pybind11::tuple 
GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, + ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); -Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, - Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, 
int64_t act_enum, - bool is_2x, bool is_dbias) { +Error_Type DActLuDBiasQuantizeInitializeFFI( + cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias")); + .Attr("is_dbias") + .Attr("act_params")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 36dd8205bf..23d46b3384 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54efa..f02876d8f4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py 
@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed(if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS + need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 @@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1028,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1037,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, 
ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1287,4 +1296,4 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layer_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad66684f2b..868bcfa057 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. 
@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2046,6 +2050,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index cf77f8e0a0..77daa4672c 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -50,6 +50,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, collective_op_sets: Tuple[tex.CollectiveOpSet] = ( tex.noop_collective_op_set, tex.noop_collective_op_set, @@ -138,13 +139,14 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -165,6 +167,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): @@ -220,6 +223,7 @@ def _layernorm_mlp( 
ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) @@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ): @@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, ctx, grad, @@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From dfe5b7dfc2288afc5d2f247709b1e0328af331e4 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Fri, 3 Oct 2025 20:09:41 +0200 Subject: [PATCH 59/78] [Common][Pytorch] Add support for the FP8 Block Scaling (ie. 
Deepseek) recipe on Blackwell (#2157) * Update to_string(NVTEScalingMode) to include block scaling Signed-off-by: Jan Bielak * Add `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors` Signed-off-by: Jan Bielak * Convert FP8 block scaling tensors to MXFP8 tensors on Blackwell and newer in GEMM Signed-off-by: Jan Bielak * Allow Blackwell and newer in Deepseek recipe compatbility check Signed-off-by: Jan Bielak * Allow data_rows % 4 != 0 in 1d kernel Signed-off-by: Jan Bielak * Load scaling factors in unswizzled order in 1d kernel Signed-off-by: Jan Bielak * Enforce use of power of two scaling Signed-off-by: Jan Bielak * Skip the FP8 block scaling exact GEMM test on Blackwell Signed-off-by: Jan Bielak * Skip further tests with pow_2_scales=False Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Initial implementation of tensor conversion for grouped gemm Signed-off-by: Jan Bielak * Skip non power of two scaling cpp unit tests Signed-off-by: Jan Bielak * Fix handling of all gather Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jan Bielak * Use compute capability 10.0 for logic with Blackwell Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Jan Bielak Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../cpp/operator/test_cast_float8blockwise.cu | 12 + .../test_float8_blockwise_gemm_exact.py | 4 
+- .../test_float8_blockwise_scaling_exact.py | 14 + transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/swizzle.h | 20 ++ .../common/swizzle/swizzle_block_scaling.cu | 321 ++++++++++++++++++ .../common/transformer_engine.cpp | 4 + .../quantize_transpose_square_blockwise.cu | 6 + .../quantize_transpose_vector_blockwise.cu | 6 + .../pytorch/csrc/extensions/gemm.cpp | 99 ++++-- transformer_engine/pytorch/csrc/util.cpp | 70 ++++ transformer_engine/pytorch/csrc/util.h | 12 + transformer_engine/pytorch/distributed.py | 8 +- transformer_engine/pytorch/fp8.py | 11 +- 14 files changed, 553 insertions(+), 35 deletions(-) create mode 100644 transformer_engine/common/swizzle/swizzle_block_scaling.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index e5faa688ce..fe4ae2d264 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 2u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. @@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 1u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. 
+ if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index ec23cfe8c5..bdc73519be 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -19,7 +20,8 @@ def fp8_blockwise_gemm_supported() -> bool: supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - return supported + emulated = get_device_compute_capability() >= (10, 0) + return supported and not emulated def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 858ce73b6b..51e0d1ec9b 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -32,6 +33,7 @@ if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() 
+recipe_emulated = get_device_compute_capability() >= (10, 0) class GetRecipes: @@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference( pow_2_scales: bool, tile_size: Tuple[int, int], ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): block_scaling_dim = 1 @@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference( tile_size: Tuple[int, int], extrema_high: bool, ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. te_dtype = TE_DType[quant_dtype] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e0fe3c04a6..92b57897de 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -127,6 +127,7 @@ list(APPEND transformer_engine_SOURCES util/multi_stream.cpp util/rtc.cpp swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 079feb4a7d..624e71d1e3 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* 
outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM + * + * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. + * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it + * not natively supported by cublasLt on architectures other than Hopper. + + * Requirements: + * - input is an FP8 block scaling tensor + * - input has rowwise usage + * - input.scale_inv is in GEMM_READY format + * - output is an MXFP8 tensor + * - output has rowwise usage + * - output.scale_inv has appropriate shape + * */ +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu new file mode 100644 index 0000000000..4be85474af --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace { +constexpr uint32_t WARP_SIZE = 32; +} // namespace +namespace swizzle_kernel_1d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where +// each thread stores a single row (of four bytes). +// Example: +// lane0.row = 0x00010203 +// lane1.row = 0x04050607 +// lane2.row = 0x08090a0b +// lane3.row = 0x0c0d0e0f +// Becomes: +// lane0.row = 0x0004080c +// lane1.row = 0x0105090d +// lane2.row = 0x02060a0e +// lane3.row = 0x03070b0f +uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row, + const uint32_t lane, + const uint32_t active_mask) { + using cu = const uint32_t; + + // Threads operate in groups of 4, and each thread stores 4 bytes at a time. + // The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes + // until we have transposed the 4x4 matrix. + cu m_0123_4567_89ab_cdef = row; + cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); + cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); + cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715); + cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea; + cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4); + cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410); + cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276); + cu m_048c_159d_26ae_37bf = (lane & 2) ? 
m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15; + + return m_048c_159d_26ae_37bf; +} + +// Expands a uint32_t to a uint4 by duplicating each byte four times. +// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404} +uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { + return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222), + __byte_perm(x, 0, 0x3333)}; +} + +// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data +// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors. +struct no_oob_tag_t {}; +constexpr no_oob_tag_t NO_OOB_TAG; + +template +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, + OOBT first_oob) { + // resolve kernel variant + constexpr bool no_oob = std::is_same_v; + static_assert(no_oob || std::is_same_v); + + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_x; + const uint32_t in_tile_x = out_tile_y; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling 
factors for this lane's initial four 1x128 tiles + uint4 sf; + if constexpr (no_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + if ((out_tile_y < tiles_y - 1) || lane < first_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + sf = uint4{0, 0, 0, 0}; + } + } + + // pack the exponent bits of the scaling factors + uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + + // partially swizzle the scaling factors + constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx); + + // transpose 4x4 matrices of scaling factors + packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); + + // broadcast the scaling factors for sixteen 1x32 tiles + sf = broadcast_uint32_t_to_uint4(packed_exponents); + + // store them cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ", + alignof(uint4), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales + 
// and a 128x4 tile in the output scales. The input scales are in transposed order. + const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; + + if (first_oob == 0) { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG); + } else { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob); + } +} +} // namespace swizzle_kernel_1d +namespace swizzle_kernel_2d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_y; + const uint32_t in_tile_x = out_tile_x; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = sizeof(float); + const void* const warp_src = in + in_tile_y * in_y_stride + 
in_tile_x * in_x_stride; + + // load scaling factor for this warp's 128x128 tile + uint32_t sf = *reinterpret_cast(warp_src); + + // broadcast it to four scaling factors for 1x32 tiles + sf = (sf << 1) | (sf >> 7); + sf = sf | (sf >> 16); + + // broadcast it to sixteen scaling factors for 1x32 tiles + const uint4 sf4{sf, sf, sf, sf}; + + // store it cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf4; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ", + alignof(float), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales + // and a 128x4 tile in the output scales. 
+ const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_2d + +void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output, + cudaStream_t stream) { + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "block_scaling_scaling_factor_input"); + CheckInputTensor(*output, "mxfp8_scaling_factor_output"); + + const NVTEScalingMode scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Output tensor must be an mxfp8 tensor"); + + NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 || + input->data.dtype == transformer_engine::DType::kFloat8E5M2, + "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8"); + NVTE_CHECK(output->data.dtype == input->data.dtype, + "Output data must have the same dtype as input data"); + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors"); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, + "Output must have E8M0 scaling factors"); + + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); + NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); + NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors"); + + NVTE_CHECK(input->data.shape.size() == 2, 
"Input data must be a matrix"); + NVTE_CHECK(output->data.shape == input->data.shape, + "Output data must have the same shape as input data"); + NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix"); + NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix"); + + const size_t data_rows = input->data.shape[0]; + const size_t data_cols = input->data.shape[1]; + const size_t input_scale_inv_rows = input->scale_inv.shape[0]; + const size_t input_scale_inv_cols = input->scale_inv.shape[1]; + const size_t output_scale_inv_rows = output->scale_inv.shape[0]; + const size_t output_scale_inv_cols = output->scale_inv.shape[1]; + + NVTE_CHECK(output_scale_inv_rows == DIVUP(data_rows, 128) * 128, + "Expected the output scaling factor matrix to have ", + DIVUP(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows, + " rows instead."); + NVTE_CHECK(output_scale_inv_cols == DIVUP(data_cols, 128) * 4, + "Expected the output scaling factor matrix to have ", + DIVUP(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols, + " columns instead."); + + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_cols, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_cols, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, + " columns, but it has ", input_scale_inv_cols, " columns instead."); + + swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } else { // scaling_mode == NVTE_BLOCK_SCALING_2D + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_rows, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + 
NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, + "Expected the input scaling factor matrix to have ", + DIVUP(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols, + " columns instead."); + + swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors); + using namespace transformer_engine; + swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input), + convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index f49fe239aa..35e8b683ad 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -64,6 +64,10 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_BLOCK_SCALING_1D: + return "NVTE_BLOCK_SCALING_1D"; + case NVTE_BLOCK_SCALING_2D: + return "NVTE_BLOCK_SCALING_2D"; case NVTE_NVFP4_1D_SCALING: return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index c3f085b877..661cf339ae 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -485,6 +486,11 @@ void 
quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index d38bf79963..fcf7a151c3 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 1364597519..15404ad9a6 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -104,6 +104,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); + const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; // Check tensor dimensions const auto& A_shape = A_tensor.shape(); @@ -235,6 +239,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { + // Convert tensors to mxfp8 and swizzle their scaling factors + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); + // Use TN GEMM to avoid having to transpose data. 
+ transa = true; + transb = false; + } + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; @@ -379,15 +396,6 @@ std::optional> te_general_grouped_gemm( std::vector bias, DType bias_type, bool single_output, std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector, te_workspace_vector; - std::vector te_A_wrappers, te_B_wrappers, wrappers; - std::vector D_vectors; - - auto none = py::none(); - - std::vector single_output_begins; - std::vector single_output_ends; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } @@ -397,6 +405,10 @@ std::optional> te_general_grouped_gemm( output_data_ptr = (*D)[0].data_ptr(); } + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, + te_pre_gelu_out_wrappers; + std::vector D_vectors; for (size_t i = 0; i < A.size(); i++) { auto te_A = makeTransformerEngineTensor(A[i], none); auto te_B = makeTransformerEngineTensor(B[i], none); @@ -462,29 +474,72 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - te_A_vector.emplace_back(te_A.data()); - te_B_vector.emplace_back(te_B.data()); - te_D_vector.emplace_back(te_D.data()); - te_bias_vector.emplace_back(te_bias.data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); - wrappers.emplace_back(std::move(te_D)); - wrappers.emplace_back(std::move(te_bias)); - wrappers.emplace_back(std::move(te_pre_gelu_out)); + te_D_wrappers.emplace_back(std::move(te_D)); + te_bias_wrappers.emplace_back(std::move(te_bias)); + 
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; + // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (transformer_engine::cuda::sm_arch() >= 100) { + // Check if is using FP8 block scaling + bool exists_tensor_using_fp8_block_scaling = false; + bool exists_tensor_not_using_fp8_block_scaling = false; + for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { + for (const TensorWrapper& tensor : *tensor_wrappers) { + const NVTEScalingMode scaling_mode = tensor.scaling_mode(); + if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) + exists_tensor_using_fp8_block_scaling = true; + else + exists_tensor_not_using_fp8_block_scaling = true; + } + } + if (exists_tensor_using_fp8_block_scaling) { + NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, + "Either all tensors or no tensor must be FP8 block scaling tensors"); + // Convert tensors to mxfp8 and swizzle their scaling factors + for (TensorWrapper& A_tensor : te_A_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); + } + for (TensorWrapper& B_tensor : te_B_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + } + // Use TN GEMM 
to avoid having to transpose data. + transa = true; + transb = false; + } + } + + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector; + for (size_t i = 0; i < te_A_wrappers.size(); i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + } + std::vector te_workspace_vector; + std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); - wrappers.emplace_back(std::move(wsp)); + te_workspace_wrappers.emplace_back(std::move(wsp)); } // For now, we only have multi-stream cublas backend. diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 3bb6be715d..ffba5b2763 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -7,6 +7,7 @@ #include "util.h" #include "common.h" +#include "common/common.h" std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { @@ -177,3 +178,72 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } + +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, + bool rowwise) { + using namespace transformer_engine::pytorch; + using transformer_engine::DIVUP; + + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { 
+ data = input.get_rowwise_data(); + for (int i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (int i = 1; i < data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + at::Tensor swizzled_scale_inv = at::empty( + std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); + // Set rowwise scaling factors on output + void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, 
transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 4b26860967..57eee86d2a 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -27,4 +27,16 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. + * + * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid + * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, + * this requires the calling code to treat the output tensor as having been tranposed in this case. + * + * Returns the swizzled scaling factor of the converted mxfp8 tensor. + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. 
+ */ +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c001e8e79a..51fbb50c4c 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1015,12 +1015,8 @@ def _post_process_fp8_blockwise_gather( if out._is_gemm_ready_format(): return out - needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() - ) - need_rowwise_scale_transpose = ( - quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() - ) + needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage + need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # columnwise compact format means doing 128x1 quantization of it diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a62e10bc57..bfe241f81b 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -64,13 +64,12 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
+ return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) def check_recipe_support(recipe: Recipe) -> None: From 5be81251d7e1c7b3d897dc2bd376901a83355f90 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:05:00 -0700 Subject: [PATCH 60/78] Fix bug where CUTLASS kernel was not being compiled for SM90a (#2235) Signed-off-by: Tim Moon --- transformer_engine/common/CMakeLists.txt | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 92b57897de..e6be47686a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -155,18 +155,12 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() - -# Disable debug build for cutlass due to hang. 
-set_source_files_properties("gemm/cutlass_grouped_gemm.cu" PROPERTIES COMPILE_FLAGS "-g0") +# CUTLASS kernels require SM90a and cause hang in debug build +set_property( + SOURCE gemm/cutlass_grouped_gemm.cu + APPEND + PROPERTY + COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") # Configure dependencies target_link_libraries(transformer_engine PUBLIC From 08779fd876562fb5d6d052e03d7fc0a5a91e1585 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 3 Oct 2025 18:14:59 -0700 Subject: [PATCH 61/78] Fix FP8 current scaling attention logic (#2234) * Fix in FP8 attention selection logic Signed-off-by: Kirthi Shankar Sivamani * Improve logic Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .../dot_product_attention/dot_product_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index a19d08ae59..88e28e3d81 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -597,9 +597,10 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in ( - "", - "Float8CurrentScaling", + elif ( + fp8_recipe.float8_current_scaling() + and _dpa_fp8_recipe in ("", "Float8CurrentScaling") + and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) ): # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe From 7e45be73bb8d513abe8785ee078ac88719bcd9f1 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Sun, 5 Oct 2025 16:48:27 -0700 Subject: [PATCH 62/78] Added the NVFP4 section to the low precision 
training tutorial (#2237) * Added the NVFP4 part to the low precision tutorial Signed-off-by: Przemek Tredak * Added the runtime results Signed-off-by: Przemek Tredak * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/examples/FP4_format.png | Bin 0 -> 50946 bytes docs/examples/FP4_linear.png | Bin 0 -> 54101 bytes docs/examples/fp8_primer.ipynb | 168 +++++++++++++++++++++++++-------- 3 files changed, 129 insertions(+), 39 deletions(-) create mode 100644 docs/examples/FP4_format.png create mode 100644 docs/examples/FP4_linear.png diff --git a/docs/examples/FP4_format.png b/docs/examples/FP4_format.png new file mode 100644 index 0000000000000000000000000000000000000000..8c54c33793db6d2bc0d00a75114ffff5996d5ef6 GIT binary patch literal 50946 zcmeFad03KNyZ_x>cavp5<7EM+LWjfiBV2NwVLy|BdA(OFO5Pb%srQ{XU#&MsD{wn^Wwa@%xjl zJLF=v1{}{h^=^}??#5WB#W(XlUESFrerkLGXX|3-bR&5|U(xZ)=V% zFfQ9LUE8OnQoC((HcJTCCn5}nG2$^!j}MCAIcqd;{-QO9|NKjF;dR1J^k$~C_S#=x zc)a6q)Zo#-yi@W!4aV1`(_3rZ85Da;j|IWW0lf0^w z?DOZ@zyG8tJKxk-%du|N@n2tfTzh5Jub;Whef#047SF<0PyRg5_n#8amHgWJdDndx z?L^DfhsH_4rQ2ffMiXySw(x7*EcP^pLytgqL4x?AL+=a93VMid|F=x^!t8!d)ry~@q_U@JEwr# 
zh#}2vGG~JM@gfE5hk~;MW*#OOXJNeB#&ApPL>sa?@=p7_Qm*q$)eH7#DOC&-AyA@- zgQ7f%l&@qESGl)2wc^g5uy=cRf3p^pR&{5%EMy#_H=Gf8s3d}@!txIefG=+fVsw!s z4QIbN^NJxB8+V`kR+%y&qrg1V8+k9=ulE%imUB;Eg?ovYath~&7gY~Hx<*CRv0$SB zJ?D2jhlg?7n~y3QUWf7IN^xoWLGQOu;5c*-b$iPe|8hf9{K`L&;$b%LXu_-?@^tRL z8`7#{$LwS8^=Sr;T2tU`0t-hJ77NFMi4=~A3LsYz{HdrzKRx$`{N8< z1jN1&8K=wFGp{tN*se!SY*Nk+FBN`XYWPwtVv9?JRp$zKTGdVaBjWl`Yg4UN$2jlI zl6My*B~j^KPD_s4+eWi?XS#-sMWYA{l9s}%Fp9iyY(};)L_scx4U2Yf2Xxf^#_|oRp9FGF> z-X^#8$mcSPSrod6TjG_|*#AjJ`^)}^@(x+NH%k?I*+%H)`Suvk&Mf;PjaZJ<%0w|? zWtZ_{%mXis1-?+sK8}8l>?#jYd=hf)iD=z(1oxD_YtSl^54J<*IELa(s~&7hA&yw(KON%g|LV(~9?lt}a?6;EP5OaKB>o`&zRH?UlZ+`^K=>laidVGY2HJI2 zVw%YSltFBALQ$@>90cS6R4T(Ip2$MwCmNrD+@@Y8tlPih<5MSL6LGeQk#$5(Jne5? z3DXGV#Q7fJr(V`Qx!TYAaFcA$*yfjgoAE@kfQZsJs)oOTZfvbXGor19@@183byZX( z%{#`^(5lVOxrUR*nyULS3d4n+$!kaCkp28%2r$Sa(uG;B`MlWv?Q0xgb9kG`dM`6A zC!dFZD4sSwG7&sXX&4o&2NL|3VtrU%3B)m_WH2-JbN5STrm@g0?H=!qsK#*irrgyJ zM})DK9`RaXu6{FLhtp|Oc_sYi+5CO>)l=p2FaeiUo+8rXk3Qw|viI{NXr6FOd18jE zf!Ua}0p1V;8fpw53aqa7sTJMJw8i18EXu=lf^)JrupL90=z8C}dmGCAcBA4vrOv;i zrSz~6_@!2jyasJP_49-vr)r0Do5k#gcb6BHuD@Ktjwgs583`u%VXtxGd%AcAU*w!T z%}&-zn5B*kPUzj5?JOzo&QU?j>%A02%CA=_v##D-eb6aMbR#Vx)m}lH+-$!)pY2Ia4lxAU*{(#s+ z3l>~7WzK`Wnq7n`EUA1!U#eCPw|5gR{g|l0-2cHVZpq~>?@w_MpGwb8kBLU`EK}cR zFG$i2`yjD?CU72!+aaQdr|1EQiG>FYoXq%8nA!W6ri|e3CC+ zy`W%OJfV|)3{Qu**#3bGZY&EcPkK1s$@ZO8w!K&W$?j;uj5UY8djP6j5+5FSGOGAI z8-}F2;CanLj;pv_%@JF~ExIuNv4%?2&3konRE~I~<3!&Y7A>;pt8tfMbXZ@Mvcrrh zqpu@98g75z zuy`sUoX4BrcV|ca!1xge;loTaz_KSf|SP9ViZw8ay~@FzuM;UDm` zBqn>+1WtUjXaFTZ+<(xy{0wAqn=yUTA)+nUEVTeh_Y{`0;dFuT)Yo=l^*(nCH8dx; za9W%~U+55V(PEt5n6g|OzNEa}4wIO5e1622+cGXZ0-jS}Oi$-a-R}{?A;IDj7rK6- zYfk4(S4x1=F$R%`nl1U#Csq@EUBN-Ukap~)QHk|0r%UJ)NvrDG3U{K5n;nEtoWyb; z@M8aIS1|)g7=u5R6^hRMoF?c+4YbP>C85Ve>MWc}i@{}Z?aUOk*a%yWitY%U<7-S4 z&jqA(6+3#7oC0H7vJHt^nAtx*_91cb5%0~b>y{po3;J{xZ>vt1adf>GzT!cM?)%jm zS#oOC>iI8KGFM>0Vz5ipry+XEr8jz=kwwH?2_w}v(9w?TNfR<7K_q9d?CDx%@5yK+ z9=V4!wll`RiR_gvtYoXmHDpy}prBD|L(bqKWm^@DpSBARb01$eIX&KglZ8xD!oF%3 
zbD?5S8Q1UtmeNce@pnb=j6!v}+;cqixa{hPc=q}iyhclNnq1pp@Xe-8oY@JDc0Nv^l-@I#XUlw*_26_1z(Q(~ zzUO9|#jASX1IqWn;*%+@2=8}YIRKs2|GZK0L(GE2?!$P8B_s6)+`ahmcd~+F2XVv= zs-cbchZY9kV_AC%9b&!)tQ&Y2*N|1GC7hd8 zr7`TQvr*(o!J^72XTr*DJ)M-M_uIgVch%ws+O*7)8Cj&%tHsN%=FC&l`dpGs_X}>* z=P|=jUg+@gH|8?VS09|t)D0pDycg1n^dTxouf}zInoIxf)I7Uyp*-w~YU;{9Aury)V03I8O=d7@RzL^p{0><(X^Jl-uczDw3}M*_i*YHTx@H zT5pWXMrRizUaY{C1m>5q8`g3s@>((-AFqh3VmD~_G^M+YywzkhtHktr-SX>cIFR=> z7S6lC?a1%iWwj~#@c#LwcKOP1M010SP0Va(ZG%gT>cH0$4Ky4)ii<-oBTjTMNonDq z9Q_zM!R|=;Z5HPn;Q5dp*JNA_XIhqDSYo6szn|;^XB|*n({S_}vlAD^>}+XpG1QzO zQDc}vxaIl}rDYus-Q+B@L?kS;Y`HI87`h?nH%sdM&EsJ2Q>jlbOn*R#@Ju`xtX)c? z?wUGb>U(i~W!Lr-oYTLbIG|}-G;}+=50-q_P*$1Xv5VF)xdplg%yc0d<4?e!|!*&()lRC|Faf+SV zQHq`5c$$ooEe~SsG~cvn%z``W>5j_Ytkgw~k!JXB2>U}xt1M&*W9>8kje61zY2Z7wp2A5A|N?+yCi5-nW_-@CnJ9%zp;JV>*Y)OKH*F$Vhi zF9+}}BuK-vcEjV^QIR{!-Km&BH?*1$A zG^A$#{BA%aWI;dY277L>=Vtc*u-nLz^~h0lCyeqGMeaEu^~shs!Yc2lBlDd1yTz`O z-?#s>+9iUu>5wR{^syEj`d|G2f;QWYI% z_$^Vba!d=}_oTes=LHT|Peby@xnyKj!=06<`*oYPonRBi5%`HlRlmW;Bt16j<`VME za5i*uxD9uq=FUkTX;++*TJvYqMV7nKdax;Sv9^<&nSVv=j#Ij1Pglb8hk?bTitZ@& zli)s-YYBPpf)Iy$S|{b($kh*T-O}=;)>er;bdXna&X?3W<5c|0{E&a?DmjB*TRN`S zg<48Bj@#@NDASvC>cMHH6Hw8Z5cVa>6}5rXVe#lZbd%Pnu`2cv%5-BPY`Xg_jIK@0 z`2NPgy=DcqC9i}oc<-U|^1iTMm34mIvUhrfbVol&Euy#EYFR&$Mr!)82d4?Ow5uJb zwGjL%hXAm=dBGy^Xju36(v(*LIdTEDYq8#)yM}*v97Hb=-t$VOE&`)QM1BbsjBWs^ zNO7MMc19`PEOx!uVeTj~%<{c!sU2M}d$Y~pC7b$jok-8XooE|xEuF3vZH7KFu9qKj zIDWX$72+D&Uhh+W?&HtDWoVqjnb^|ZhEHc-7)H!a&i6kpP_o>KCWeZR>v{JmVf_R! 
zM^14eEB-7LMg3MSr^jTP4fXJB^*u3$0m&_*bPOf?%(b;Di!Paa7To&E57Ov1QFisE z;CSS>9+S+^!34Kq1O^(37=Nn6O;9j^Cxa)M3@`}Oy6DBu^VplzbH4$|H-)Nvls;_?e;)mDt9 z^N=R{+TP`zfp`fgam*g52K1bD)9Zs5#fD}>v!-WvRp{iFjh*{uR~t;9&-ZdGJY;*#mA$+Dm|~JN-jAZde2dx9^J;r+ zsYOeMkA?V3aZA1+I*MPdQ$@eK@4R;PETkqoczWo-p>U?OqX8w1ZDavnhf4`3BXEHXEKAFOG@o@l zMeD$9=t@ktl9o0GG=IjIvE9(YjU>*lfHk;Q6^ap)QWR!JzeZUK5{LfMX5{KsklZ~% z?!vRH@wS%raSC*?s4}Q9X0zryEr5i+4RdaSK7AElAcz)#k?6B6MviTUGv+TP`G3~P z3cDpu8Gy8?nSj|+kyf`G?GAlZ_I#H`p{}`dYnn9Wx6VqpWTl`cRaw?MJq7gI?cy@} zJ(j6|B-G`h5_Wfi29yRNR{2~fT*}tw;xk!Cx8Jd6N~fpWFIR$56wF44PCGJVJGFZn z5GGSZsAycNuEo^-jVy_kS#~=_1X>B6SJanYb#k!{yGnjMzwVA_KZ@33y?0ipj<58&Q5aQ%BQk8 zGdfCLCZ`|uPh^oM^f8IunJ?AZp=$J-;8!|!=SovX^-+TYSY>(ba$;l8#x`5Yix&M3 zkivtu8;0NQos}&7oOekn;`mj+mEy$JLwug;a;0QuC(>l6rjtS3Tcz|u-PkJ>=Xy(s zbMB$~a`+DH!QD#OfSv2JPxXFH?S||>l7_Q41#ro|&$z92OAm z0XFi5LN%{r%90cL?*gA3M90J0YC+e(pC^denijL$WmFH&^`(@0m=U+-XN7sQ92Yb~ zS6!Uir9$#G*GW-9FC^Y^kiX8W9erIZqD$(YDoNd**#^Dp7bRCqwK8_) ztMC7v$B_H`rIC0}cu(=i!g%*>X*2HF>}rm!*a0Db9G{8$em&qRjLAULqPVn!`JGL- zn=DMX>U(abuX7yk86A8pb!f&F_Ey=G^2N<$Tpyj924>NUN?Q#X(MS&QrA5TWyciwd zt9D2W-woOE-QXAg7G80xdDZ9^MSQW@P=8iWQ|Y#1G4fc*8E9N)#M?!(hc3pe ziQBhh$CK<>W&Q6cEYf}vt+DlanSuB6BWVftyOgkdRI_7PEEc-oJoVzuCA^+Ec4!0E zH==V#=KeP&l^0<5%|?TK@|+;ivW^CpdKW(^rh{kis?GaqlhqtO>)=6or&1wiDeBZ9 za}K-z(o|GkBl=j6E$xPZii=0N}n!_rOw8nKHc*f zM=fi|7+gVTUT#dyGGR5Bfyo%s71Um;(34UkHZOj)QT%S}lvX;Yy#PjA8p~>a50^iX zq}Wmiu6f5pjLKO$8@)IXTC#!EN57sU^->Bsx?6AN{<=e`yQBHJE|6$1XBl;;izzN) zm%e{item5O4uVRIrKSEGg6w7(=`}q`P;{gM^?XmbZQ1eRPhL$FEC&Lc4sa4 zoI5%cZ`*C~AQQ=G+?Fn+PaM>2pdB1jF9i3OMDEMIV2}Le=N|c9`=SE>clSdilPBvh zdqVQuzIP5XBrJm;46XT_7T^c5nF9`UbY~8w{f8*60YKC3CSz{t+R){!G)@>n9|@rd zVD5<3PoRf9(?TpP;y4`6r3CJ`_U*>^U!G2%cE(1!E+JaB9Wpz*_xC2aimqpi#T;V=e}^j7|rtG{6PdDoYN z1>*+Dm7V`^xD~AK^cYR;|Dc)K1(?3=-SiJMD$+19`Gvc$c@Dt4MHDkx-P!~VceY; zoWsc?ENt5Fl8$Hepg1t^=;-7mRNA1M!>z;uxDo#ZO5o4Dg znkcBZrT9fxd-{+Lozb(kfTHBYB2CuQ&4Iq(;-4%uh_G3DLgj>=L1kCns z{kCuuoPDnj#>!~xyvL~R-$D8CrN)ngXkvA$rTa}i3eXCE-ahc&b 
z1n3KI{l>j>!HCr=kS~itRe49o$$9V)lTF{>xL*eI$&?w*Ktk=pxrN+XBOf2$71JVa z)7Zaa%;gKGI2`h3uZ#;#DH^!HNdvJ<)oV(3R)fvIpI+AdOUs|xLXhIOliW^K0}SJc zyQ%LaV@JwXl~xc>W{Xt@HNbW_dtc5AU-JWiST%Nj*5TeB&MuGEIQX8=$5N_pi!ZP@ z9ms<{>pS~J$MunaTv2s>Yq~*At6+Lq-|jigVYw1lyN!6!qqx5tPU!d+c(P6 z6Yr#I1V;OKBH2Pm6Ll0J?1*Ynby@6hk8lHDQ-=jB?iGP{23iD$BD6cx*of_<@H=hwT| zgq-i$TAq)r=<;-#I%4XZP`E#T>&e{Op^F>wU_NEWMeV}qbUNkL!lre1{mRYndA#-y zpQ*cJ$o|gtx|Oit)U?Q~qNGnedH%2vj4zh@?zFIs@a<#X9mBW_;1h==r&lXU8jirw zRv-LYT;4R4A18vusiIeG-2GBHE9B`Q)$`Nyvsyk~MR(j9FT+Q2Q;GngAqu$+?GopB zEmbto4V_NgU%>6B(Od*TvU+iyviZQmOvvnw$#a@XK*u!2Y14Ou>VN*+V?WzhV6qc^ znyZQ1dh$u_koWgKc`j_B!HCI)%|jFC3I;)dhibFub#}@>fb_MLq~7`J>;w9+D5`{f zPYkJXoA4XE)48E;V#9ml>4$xuN+ia!r3Y1DQ7tVJKt$xg9YD~e&ASkPxKPUT1+wjQ zp|oR_25h{q4%VtOTvmlE=dZ8rGs^L=o*BZ{Z~9D49oew#nKguA5h+ax5BwW8S4liq z8azCFeA*L=ckQtGaij7J-mnq`4A_x=HjrpuU29dJtk#G|{>saB*Dd0l0sVIU_#KNA zJX*b2MG3L;(pEkrDFkvIrr+hcx9B z)M-<0Vdy_q>fxNYITNh6+W-qY!=l?bxY~)z4B~Ltnf;o zQwxw~@Y<{~IO^tGjYE&X^XA9b_ydT`eV%aX06_0M@ZMi#9(FAOFg_=sP2n#lMJcNJ zAO8uYZ468)qk8PCXFac40woON;h#zvGaWc3^Y~rg*wpXe&Eds{P|lSv6A98LahvOT zwO#hOu!?UBEkr4q-=fHpAZoaUZ=PgE$e88duqG;@vd{hrn7zUnhMx+^+iUBl1KadT z=6bCC=~{jc#M}4SPxNO75@ycl@6y`+dhbrO?iCOx>xBRSi`y#cq#)x87poV&t5jzY z+$^y81-qt0{~3dy{sn_4sQrS$L9aEEKiw<~-;bT1`dna@-xI&h$d_E|VOwSA#Y+)) z5##c!d%IF-eFuSd{m81(@p=W6P{9?Qk2#mGYa0kW%|AC)m!Ut>{rGl znlekFNO}ZkGT+DU?VVwlvPfjryVY9}*9gN%1zQ1BKoqLMM9OVNUromiPmLaNSVVn$ zdj$)G(26DexEt-}BZl+5Y)C*MX`+BA^Xq>3CCD|!Dm=5*IhbmrXM5Hu8hJcaqTKiE zDzY!Pa2jysF96RMwg#atOVru3r6-z^jX5BNDqFZZJYNnreX9ovZxSu-N$oI@lS=|} ztZR~W!4t{4fHu?o{P%RL77x>+whdqQ=7)4~ve*&z{-Sqw-gV{EcDKITD^zrwA5;=4 zG;s10P2)GVRjNC1jRYS#^a4q9AKj#RVhnUUTLrJt^7s!GD&K=QAZRIv&|^l+w4k$R zzdgv4see>sOyPj|b^Mc?dcpH`w(4-C8JCx@B;Q;F!W1wd_*;*?kQ|Q$W{M*s!+w>U_kK znZncugfzQ`c4t?J5u|kEUo_T|q&_b|m@_$pU;k8IfB$I%Ak3<@184r>a3*LctScpF z`!8DUW!5_nM-{J0F#1)(z3i$jkapX<#>W0t*u6~g3<%F-kDgk6``5!C>;W40Lg?uJ zLkn*MG`@VV&ia2i{6DGT*Fij?pC*(SkbClV3%fl`=)nWa;GrEu>UxE(S*_LE=@&Ji 
zVY!&-tL^P3W14TZ{06pIVXyY$r7+nB<+iDB&nuvI)2o%`K~gMNwfYUlYT&$Qzy?!!RK`M|I2ju z*UCQ1{!I%om+qWO5E|KS*6EGR8NaT zW!Rxl6p9<$P#7X$8oIa}-N_5c9F%^BKp?oe8i@BLm7}{r8vQ*r=gLov9{gW2@qq$c zYK!}DHnJMz@JVRmP|&Fzx7Ym}Bz;+Nwdw=H>J#`?+d%dW*zE+ALMTdaFuwQ7y^O z-t&gG$M-oklz#KODD4|Mhgl7c&;f6>j5pxDWwYE;oN@7JYcScY2^)A|(gOtJOptYtHImZ5I*cHq z+FB3xuXL;quW*KuJzz6~DRyzqYgN~sD5&i+yc^!y=*u9fZm9bWdW`K2h=-W|L-z!W|tquulJJ$Wv|IrZ}CLk3f6nEK+L>l(F{n_ zD`Km7;wawlOX;2Crde>#eXY`(i*#)!?e;I?&%;p`DkkJ6_v2> z7a!q{@)9#jwIwWZl$6?b|3&F{GR}9}@HPX8ybnl;y^5bzGvaFmE$Rzl{Y@HLvlTC` zJ&XVl1TVuC!b)u|>F}@FJU!gLJuV@atZK6DmTcR)^EM!9z4;`+m+erb87tZF+W_(` z#F2U7Ka-c=ZpwN11SsYE79oWUSBf5ernRS{#D&mPSex92r6G9K@`t+M#v$Sf4U-2~ z$!hJN4ORNE^_lc_fMEmnF|y(BXkoA4luz>Otm4i<(wTe64CKi0LxO}qlf=4{sX8Bi z=!6aIZtcn6UEt_KzNk%`xf`=-CST|Yz~@`bq*XD_kfvYpI5E&irYz~mvZba#;7A7n z9mc9&ZvDf`(2k9GH%^S!75Z8sI~^H*G63^X)k!{SCCfTWElUR>NeidF`yn_VQ&>IZ z2x&5j0WK3ts#!r)kf_b4^L}7=ob<)jmQja+I$CsVC%Off!Ym&6Xu+R9AQ4aECv_Re zfSCIR;wOuwjL#QUiG(kUO?}6!yHoALYjy|+6X%y4{5AsUxDbBDA*s6$fyOn@6lDf| zY+qbm9EGZzHuF_WMiUal5F0HWk)akWiLT6mnJB)K@4t(MQ1p}3cW9~&=inbv!A4P<77wFdzvVF{EP=^+LrLYG_7|KXvyK8Dri0Q zcRmjVp@7*1AaK+b?2^K(N@i0=x}^19yFyRr3l@1GqWm~OlTOe{>^5h)8~`kTG!{TO z+idM4Tr8cGh{3($pcRhofKyzJGPit5Xi?8TzE&0R^UO+1QAn=@px5j~3$>q6HPLk! 
zZBU=U z;ov*8JVDTlirRpBP^E0^)sN7rw9nbBvh)%V#`6ZP12b*xAyyIW)5B$zlMUcxu5Kyk z0AI%RmTF5x-?o97ord2W!z06lL{C?vVJnDLtB0v?5XrJ;d%0lh3rq=DqmBCVaC~sfm)y05lLBG412dwfuIba{mQ2A!2Q(HlIFE9q(Q9X;m?gt9bn_B#qR7@ zwmy_&Av|;sW><6A)c2kQ+vEB-%vSxPE%%}VsraDgeGnk6@!H)K2<3N|eQksoxr};2a1}uMYWlzM+XsFWdrY)HZd^5ODo&Od#;9Sq9q& znQHz%%cXu#Ak2j81_ZP`O9Sxwy13dKeyjBP>dm@EbP_Ys6_xN22*i~Wj>ac%+w)yg z-lFnTQf_nsX!y$r8;`C&p>x7&Svy3d@u{wRcYFjJyro_}38>EfeU!7JSf4oECAHYF zcjXy+;nUxMD7{$xhljzZ0xLZ`YJdv5QVArjj-M*Zwg(mxVC&!ZBJE3hO96<9U|*=_ zv;;K%hb(>hRkI!q(7Z0g-w!><)teh$xUoupe?zkP>)P*`Ts;B(EUr&45mzd-m-0Qf z>j=L^Zw#x9oH@mE&!T?Q!}H!K`!yedVsfMmpKz|mr>)8fR#MKv>yz6dy(Tei=lITQoCAr07P^lY)0s8DogW`xdwMXzkiP&5c zc|c5Cd6D0HWhlD*mMNQcB`Rp9n#@m`hjPTV#v1T1`f%mb`$AU8sF#^(5Xx*N;e~DL zB3{|04&BJ+2jDsWQ;`1tryzayFM@R4oW>tr8Qp+hdf!6nMBr2NZ6^XJO}8Bj+-Fu% zJ2cSXw`zO5Y1FDfY-iC??r5wepLzzQF?0Vp$1?e&Cb~y0PCHF8#;dKd5+qlETgNRqqfgzIzJ@qbQ|x$O8kD}p)Sho8O=b6g)p9_G0I9M_+-c#+ue%vrq5`RUA= z7yY+6Mg2Fh{C_Pe*Hghb-6%t$+N!5=HOMx0&K_2Ir_i}(DWOC%>yT0Cx87qf?F;M}PD@0GP{aNg8pkWDg7f zs_CZqH*&iu4tV(fR~osfoXCcM;c?MTB@~-k6*Ph87mOb^V+v@zYLc4fPl>$HUlDWTw||zAn_Xve@aY5bgaZA$r?`A) zpV1;9`bR{aRslQm(WK9NMgLMrp7^hXZo8fcxZ}+RWl<-Yd(#6qamqIVp$r#} zt*LMa94;Ajw(FCR;|*d`#6I!ec!gPn6eT+g zSc9vgy*HOGWr@AsACgT!8bz)2!^foaFX;{4HDb9|hXBzhQk*SOjwW_*EZ1A!&_!lrVq$){45 zkaQ^qz2b8)x(2jX^t+cTEzl+I{ zKPkDByE{ru&`4pP20CioO(us(OH8$K*Hp4yi@NN~R8%t?O}5kOi3v7H;TRbqH7083rF6Zi2c&z?PzO_;fsh#XBHDsfc?s_mDgz4sWe*5oF|^?1bgCDMt-wi!PHMQx0QIkT;| z?5ly4+*uu5)8VRzbL@^hAo#=6u5c|ORPM_N>qaR>nYQ@WEPV>rm;8}kYp&0+PCpbn z!7#_J`=Maxb9T3Y?^oK+Zx6evy=6-aV6-Y|zI^`HXZY;X${>!A1K6+z`eX3T%`DI# z)-4zDZIt(wt*iG9cFD8`bpw9X2#_CaQ$d$lF5E{Cv6 z-gcR=FT&MsZH>PYpSS609gP`vq$kR!l)4==s4VB-{qhMta{q=n^KAH71q@g=lL7hU zQ_Fxm6UE$M#@(R1V4ZB&bvH{$kTsB!_otpS)XV~AlZ;wT^jM*t-bdy1?;(eBl1WES zQrkQ36t@RtRXMUhzu|s}tA?Wd$-e7$Wcp%IF6+N0m%09@a`IP5sD}U}-+dQkV5JB5 z`dNr3QhVMX!hZ>dv2R-RNF>)L3K1t*M7de4Hf7lG3jeh^_68ZSUk|#dS~xmvL2C7` z3DoGRF?4ZT8f);HHc6B!6yiV-bfca)!OQk?*FwExKzS~{KL$k#992z+f$BUMLGlr9 
zZv?h0_EjO{C8*{CWR%UgO6_-1J?n3xdd-(UI|rCRK`g}Heyk~1Cviq`Xh^Qy z<_tL=;D^;fK}??nwNhRixkH95qEr%zvjO<{xaJxg@PcWf-JU4Bz|q$m$Ry ztSy-6rj6`ZE$?%fRXsXY6zAltoFOfi8sJ-*DxHvcr>F^@AV&D??@lc{-G7mrt1P-C z#C&RL-Jg_lKHm}&JjHsc;j+IzVxq<*&1LIPzc(MFJFCZf7Y|$>NHi;I62M0d?_UCo z2iR1*31tj~zE;+eMtXs9F!u77w+;bnq78!_CZ-JwOy0@B;(zpUU3|C1hl_7_=z!+Q zLDnJgK{@0MNvw<=&hKbU-t=Q6{3cEWrVAI=`2$F4*dugpJfccmn zGOx23c*Z6FM;N>jVDREtJ%GU(ZKG)bgHP@<{)@}$cKE+G7X_Eq-2P^6fAinAznOEw zowFSd{a@h%gs-lR-C7anN2^jP^XX!6xU*4yY{~Nh#d{ z?1(s_FUu;^hFO}!*1K|xP<1IVCDX>s!i)X@T9%?s_((op`St5|-sTJGT32FJ9CgR2 zkz|W(hr&}Ab}wjO)Zmht#pLYW4!1clVYhBO`~WIG2TXvma*_+S?+s_$X^gAS#93&0 zaUAV|F!a`;>)Y6w_fqG%ga`+;Rc~!xnHloI9XOqnt{qlf0i4k3JAUR5Hn{#&GqZp7QE@EK$B|-@;#64seI2ck zh7Z?p&p4I^i!ae*St!m`Q+0FU#GPmBpQ$Q1`?yvFLLTEQA>^*NBJS<|DVX*qkzmk$~izl zfOR(e_|$P;iRm=zJV$#fPhPC2A4%G>Cx(I(q-S^~#f5GSScamTlRvM$c$< z6lR9SECJmnP8&e zD(hVJ5YcWE&4SBDKvo!pK2e*d%h&{dS!Qxt=~%EY^1bNFyCf_U8K>rmsWm1H-@-*` z!B=<7Xt=r?^51w0roA-KI@yF{2l3u_fE%v1fA2*cEv!(INdX8RuPcbu;U-C~D0aqo zNvb)5x`Y^D)n2M%zIW)8l|}j3n7`^<9Lmi(u;l;@tW|7pHlH<<2=28EqGqof5*wfYECs zFQrm3zNfQ~m&=#d%OkC&QSU&<1o+*I+$o}nLG(sFHV6JJBhc20z~LT{Ls5i!>R{OQ z{A^Ch0o#V`b4K7ps8k}lNj(G7ubkzXordeqf_oOW;Af2zyCc+BmfDG<-v2!wkff92 zTnc=ohiuyG(W5A=LB$B&TZyBuZ1MHL3;CCL7kEV*+ivz7|4{gH;2D0t&5wS`E!MD- zyGemM!3JHCF%oumEmfdcsur*ncyxG7iMVDU*DjbhBm6)DR_IufH)BY_;1X`O38FWH ztjg!L4M8heeHe8NDS}Yg#*WnT;Ym9p2VsCDAvV$w08`(V9`bv#2;Ao2 zSR%iZR16$N#RpWxgb&}P8Kx7334J0{|BILxrp%AIW~_qzEblPg9Fw>bVZlp22(ERE zB@)`WwlyW28}&PBv(2O}1h{;e*-M>pg0qb7Q1?q!VLN z9Ec$EB#7_C_O2^*_3BN@Nv~*rPNz7mKR4Ka`IU0laqaBAKroL2L|tO=@5luu7v8%F zb=gp3;d^>t;3rukzmW!O_B8bkC0S;hcyq>TicS6Nq)=@&>z-AH?K;v!$4lXsI|%PD z)!Kd=4q^N}&~yE_ub9AB%GhP^B~?}V97gT}$BSiQyf3y&b}}BuNs@>NxDQKeK4BDu zqXr1bN#HyE4anhrw5%Xz{3pJ!N1jM^CuDx@Ug!RqjP4zGtE4IXqki1sw?V|;+)O5Q z(5n(#L{8KBapJcn(+!6hx*n$-k>pOlou#wg`p7-sKePr1&Y7|8Wj2WMOD%6^c%gAuxQKG%4Q7{MlAl%1uKW#95g@$umD# z!KE7$1Um)bmhB{W>+-<*8)f60zZ}Vw<9@*?Q$8QL>C2MRrj(HAFQu|{-d6|Zi0rc$ z+?@QdKI^vpgnuy{pATVdmMkTk?seO@;s 
z9#;?tWUq9qk1YqaN=enE5QR1h{nD$5q9dRV3@H2H+YrAE>)EnJb!;Y8Cor$rg-~hg z8SDHIC{C=oj232YBjf!W*Lc=S7ARYDlF^$>oBp7#U#i2NdmNH4t zJ8ML(`;rV})Z|@v9dUS7d~R}~z!2B#tXfiMcXoDIJh&S_(AN5V=*IX#Z;*}>(iy}u zG(2x6Lq7}NL5gyV#fUitD_FggZCvma&Km%6KRv-a!-XY8qrTQmJ-?bet0XmMZ+6_X zb4~%~^meX@ARwrBJw6|3F27d&6Tj5D`EFPv0h+-1FrFJT7dsmg&KVY(t1}0 z)?z`$_YWoZC=z)bjwKEs?iyahf8&-H6@2yTIbhYUHL^xQA5y!FLP)23L{59ew~T-c zJ=+J{AB`;Fubi;c{K1+j)wOqO-PQiVO5GbQTs!~ScwHJCR{Ce-^|^Qdt1;;TVH_k| z!1Qv=F0nky(Wbn-Q1i@f@4;BG zzYqRP735@il;1FbP?fIr3*~2~EVO#`RYf$c!OTV0Il0VCSXZ{o`II(i_nV;P+PCp0 z-NW8TdICRRQ*#Q7Qxgs0SU0`QOQF<*#3Cy-)|~O4h$0;8LJCoxq___ZVOc%>-S>*%aD##_9+39RuB5c)aW6 zzOlzZfq#9WK8#eX-Sa{ZTKc}DwS2L5(cS1NY$wRL*OL&H#P7Aho-ppi+CYqfk+kP^ zfD{e)vAXgr3IrJdUGS9z8OEQKbKlYQweozb)Li(PKY|Qz%~0L-sEU#E+av`U*RZF@ zYGC9bNnur~Y<+SBvippHYh;LZMiDQpdl!q1^cPr4AJ+oWm5Eq8-!H?}5DR^5`I1vn zoe^Io#`OLKj%~UjvN}dqd+hT-*tQ_GYze6bPR*T_czp3PIv*Vt7RH>NY{QlNkkTyg zmSfIeO=4XvAkL>x?icC;8KGUJQ`!Ff_(47(zPgz)l)S|>NB;&`mrveL;%y=b~a}P zB&)d!6;Z%*X3Ui^U7D#S*u|g`>005FOY}mg#>?C5RLL^V<=xe|FUr`@2{Ju5#E7#q zGBF{wiASfZq52m|ZGsm;Fj#jsVqhUR6Hl~=+Q7pQzr^I0O^f3WO*G`0Wa$N$sM(i3 zYx_IM#y%4kOzZ(d%y-HU48IM_ka?uuI#u*(El^}TV&9cU+V`){*~=r=UqVomLv5x4 zi#)WQeSH3(_Rch_sY6}EwY0_73W_47Oj@f{5YQqZB1s)83KaxIP$oqsDwBeM5<_aO zQba%tf=s~yLh?gfJvR1O$O3M1}+;B)Q)P+Il*iwa)o**E#pnA6=^y zA@1z#WIx~Yz3)ERVL;~|Cu28>e>~z5z55n(h}Rj=gTb+F(>K=B#T7NaHAeC$^|Dv* z!%%`3EJL4bqEFiSNY>_vu38SX$rG4FGB}Du+|!&2jJx-e)v;9KGTbwQv^^-K!m7_6 zs2lCx(uPT59jc)b~{aZ%V(iC93CCXE2^kW7W|Qcijj|0Pgs!H_K4J)~=?26}jd zq`AC8{KDAPud`R@Hd9>K02EZ=*MWj`lYY>SGLYOUwZz-8hgqI}xWFa`3mQ^w3BUO4 zwLBO~69=YqNfTY7vlB}pXyqY+LvO0o#Ge+dN0s(+*N}G`UV?|DjG|;h2o8FPF>d%C z1LyGaHxLi*_C}#vc=3}Mi-8-TKeE%2SFCXMN;GsoVFxC4lg+`g{>Eao>OOiidR~5F zGe8>FWEhL$^wddP<1kSTq6?N`N_#H)h18z9b3#`dm1PE5P`>DPo^13Q08u1{s4%C1 z%3^VBVJ_lV%u$#Mr)3P`aES(TvA4$h5#=x~veVOovl+ zUq&zS7qJE#Qa;z=L3f3qJIcsEcqd)Q9s%#wu*7ZnNdc8sV2145&_yU^h^OC&ZWN@} zL99y|`k~OQQ<*c3(Cb|%7N&@G*;e+SzUdN~rvX&143A(|Vgr*Q4T9cgp@yqB+>FTy z|9MqC_$fEmFf2};G8L&Dke^ALR0lO(659DDn^YM3cJ%Hfw3@m|CqK5lphAqARZbJ0 
zmn~7loV;gwK3AA5zY>BBHs!%x8=86vt;qMc$#$JJeekxoFR;N!O4qXXl_z|V-9$3U z+QOLXH4?yp^&YOa{!P~t$<&grN`>OERg}+cB7<++c`bs~hL7MgF=RR;e0|N)fEB6T zan+#L3e=20_dM#i`SID;X&7Vg+Wyg-f_NC{F?f|{_Ys}Fq7)zsDF_^UIH&5FH~{gr zwrh4eeriux=r0Yr7qe%G)Ji~v;;u3T7P?z(O5fag} z;rXXSk=GwgYQYIGRBcJ_b;Crk3m+ZrT*BchHKX-I`f8IXVWPosvm6%iaJ`aVw|SbG zYvI$p;jTzl9YnRfC`g2#bJ!`qZBaXC@H@_m=?TQy@3{ffNPXzN#PMX?tAoA}V~rQ0 zUg#YY8EdzF>JiD3=FETIZNFEl6e&W_^M^>Gs^|QMi&4scXz5r}UGMI@r`^5d4d7(> zb~ZJ$hu{)Qbt|N@)G6EG`sgYM{Foa6)Lay5t=VKkXo&Rls?#Uu88H#OKAy5JWJCHA z&W|H|u&XE{Csx?@;VUCn=t=wOhJ=cgjV9qHvTVnml=Y(V93R>}x>Wl>XVL=^D#P># z6{2uH1aLrCa+Fo7qj*`YS%vO37$2|$OnUZI!_0t4y!oTBHEQtcttVl)T3cyPutON8 zyZiLV4!9=A`Op7^(mw~3{>4s_595YCeg}0z9<>!a(G|ggO*B`mV^%k>M@Rm^z>_}p z@{vTJ8eOG_B48k#%k)U~BG*FiXjNetf8<0ySA8AP7di6QGb<$OB-lxC5?SX;R=&|` zpC;`L(~P8lPTF@?5d%-YL)x?c8EId2ag+>yo$g01>5H|S+YjlWumwnNO^fVgp&(4U zQ1UV!jRpSkE_2S_oU=FQ?9DlQ|8JeW z0)Y*``kGZ6!Ln>ynK#V?QzdMd5ft{PZ|hf#OV#g2A+E8K*OxlO1#PXnW9;E`VU2Dd5=raSnX`$w90ozpG;S9Obzmf|OMH3S4ze1L2ExgyF5D@<0JC`+xa zH%P7h<%(?zsGs|6GM8moDk`--&Kz&+4hZL<&h&jLX^9?pf4o=ZbC1a>qML%{=i`^Sy5@M&7InLj5t*ObhVvD z>#`U3ztG;JqYpFjkgGK3{n{nJx&Yzc*$^G=%w?RfJw^rYsr2f^qCV6auC-+FHX}#1 zbq%Ke{;ekwoXUt6HX_8`$RDPh83_G(O#zVk1x7S49Y)`3?5oK1)*>)FjY>6Z9rVcK zINp}B@J?s&-fIUOecOb}#Ssi@qj>1-lTust!_?dZsoSsBWi$POQ`-Vgxu8aL@PfGL zkQ8H#ZQg_WWRn^t^wdeD4OYM|u32rHdJIJWo7EJV$s*i#2+6;m=ut#kJ5~Eg7m0H> zx~3VJvGGDY++cbmLD*(fq=Aqcz#!lI^k=H2Q}yNDn!JJb0Gk9UF>PHbk#geChWXp~ zV(*Gn;EvdWF1Zh?f2-dfOhnx*s$NB@5ceYWW*ubdDoKsHGg~ILr1;jUXov_)Iw#xt z<_A5?&s~%DE2ao~(rL`TbK;*{&>K}_tC5tD0RxZS(qAG_M?$$c>Zm^X9SrY>OeWwc zB%kOhtU?wb++n*-)hB4;JQ8E8k-lB_Fx~SA*E2{Bq2({#0>8{uNlBT`s#O{@%SAD? 
z3K2@CICACRAk{=8UC>+kb)E;j$Je?xVeIoF_~a^`>p-hZV543?yMjoHMav!lM`xBW zfkRVG=^%z1J-D+Yx|`)W-HSX4oC~%rudgl1PCOw^^{7}~AKJHW>YXn#F!Q_*)& zVHv`sbqjvw(XYp5cr@SNaJmY1n5(}YA_Xj7dXh%wu}uIXj%424eeVxKxI&R$3HwCv z6g)N4Y`MhRw=ig$S7`IHrU7%9;Vd00zMoxR(o_@fewPzEFhxgP zZzM=vdN|9i`i!ZIe8n!$HO|CQ*BE5hou2y%t4r9zjpwQ5?i7rI^f1&-Ha$ zD4`CDgxQwx{3|dnC2OC@dezLcExD?S*?IMZBz0?KFJ05GVN#=m@3AkF8EEHgF1hy1 z9(pou0c9aUEo;4})nnf8p7&+lc*dKHM$Mha$fWV4W7ET@Py^A~jD3Kkb6)L379}=s z$ssPuNZt>!EF)CMU$ux0R_f^-C<{#=&_aior$IIf&2}tOmvFyZepzxPYF(N!re0Xi zJBiy$gg@XEsAyQb_!-g?n>@P%iyXJk(Y!OU5#F8(CAwmS8|TVKjT(;q1G4_2MxQ)# zG;yp-1^&Q-MnKn7O;z|q&gYI0g44dfn}^m7xhcHrBIlbY>KBn3@P4VS{PwP}<;q_+ zyKApJ{A@A&${6?vctwSWSyl+WmH~qHZ-9~4>wJQo1;MKEZ=`^uAuEOy+n$Xtx3w02 z=ay=?B`09hmYk$a2QS{Y!PWOy(1xP&@Gu)ZTjL4J=G>c4OKb4yKzN(Qx~te&#DFi-q_I`^#)_Dz**E-xdL)BLe+tBi9m zROe{Aye+~teb4ucf9=(*#Qcc7mm_a^zr3`?zB*s!D&QKK#F zil7M8#OGlTkQ{SlmkP--H}Pq4In~fmn`EpUe$LKk{SM|**cj?$Oxj|H+Vf0e5<|-m z+W8*{s7jlN2zmD7Q=rvjWL@@$g=nofBE@tT7f+WZPtsulhXFeUFoUQm*GtwUo?(g3 z3p@Qn&-K(yW8f8GT~+x-g;}19`t2VWKX- z>-tH|z62Pm73BY9y2%zQzBMp2J>Ljj*0k|8q-?Gq`4Cogdeip@TYZTQc12!c7gdPp zk&Up_uyM!`K3BNME0Na&TJD-f%X69JbGPwxxAAkg@!Gp07a9;Ag~6p^*WA>J{4cjX zR#S|~E3LE-Va0qUL9!?;O^bn%V3m&Gly2#4_D()yWfdVB8#@-KG~Bpy$CdjXWpCX> ze)9aNJrxD-R;uPdBpKdpSdKRVygB-bd8xq0cjtF4Ti=`auA6-i%0pfAjI6Bw@ylP? z!GN|r8s#H#jiik@OIXR>zS>5tF~ zxzRD+qcxq|QhB_pjb0_JlWpEgczAWZ;qiCx8I^~R5hX}&OW^ZQv7ii|pUZ98d_gqd z0T(ylq3((b@xLRYY0f8~^U32~>R=^$&L{s5@yYj04W#=PVfRhZ{BR?U?3_P%)e9$f zCL3LcqEWc9Taeg!IEwMs?^hYht};jVjl+WH&2TG1cxvpEha2BtlhCH1 z7{3ZVVv(0Dv?)lqYhIl^r48H{ghQS(&t26Y*3 zlf4$8$5pX|c&P>-?neA1^;}->?@G1pV=6(otK68y^O1d_5Ri! zO6Cr!TdY@^A}DojyGCE=Y4!G0FKEqEa6=o>Cd%F4s=aoxYY2BO;Cl*wY;-Snio#q} z!ySt8>zw%B{cB9V`$%Mi;+YM3iNGH_hdUw!1w6X{g-sT|;wy!|ZwD3EkPRqCY9N;L zaon(9^3%ticzphSw4_>*ZxM=5ht>R6rtr!=MK}Aeh2^8R15T}#7yLMsIwZ}QH#9`q zonorB*%nshkjNZLh~Pl>B5dj>j@FC*B9Wx3SmFW1LL2O_q`1E+@)xt6Z? 
zKkL8xZ$Rd#^~jO&9iPqT_&BbL;}nHD+92bbmvyfk{fgXeM1y<|mPc?$Nf1Wh^c>{p zNUI6?6Y%PMY!cQE6y9sOaJu=Pefj4+15Ix)V0vW|vyk?Sj=odOCcfk7GEPu~Z}Hhg z{7)*JhND*IiWs%f_3;PM2yvIU-c7-i|?vJBPmyEx%tE zWA49KK$dq80MJ-gcEe`+fliHEBgJ!@l^mhkuuI<%{63pOjhbj@Zw*{Xg?!rpCW;l7oJ>G-4_oB zjx+elXD`P!gb{v_q9O5mZ7_9dw|WF25=j~`nj?*S-VrG>z5|A8W-2YO{`Qrb+SI)K z?*Wq|IUrpy&Vu1mUUVnwi=$A_LQjX@%R)2UJ#L&a6fmOX+aK89W*O5fT$1vAVM zj)?lImIh#*LkWT6W%$9egOA?ll!XS6mTZC-r!SP-DvlPAbg#J=PCwJIA>*;KM7JeM z5V|_kt@^T}4M1er;DFvdmMQ_+Fb-@}JYt*gsfaJD2nJFTMdv0joeN^2HVJZ^89_js@K z(H%}m8c4#d76K!ASBS08<@*IMtCLwjGf8p9>3vfRXC$Tced~#EB-VrB*bNiFbGE)q zd_a*lV#e3O;|kxmIHTguS3Mn(DYR(D=O^n3=J}+i!RX%&>3=*hluc0uGex{JKlsnl zABoFZW@*1ZauRbe_O?*Hj*y+lsKEDNs!cVxT->JU2Dj0>wnqhS!Bw}IgOqoo9bup{ z6#Ec=WdBTRy_B@@xVMzm7E^DV64>`CPjhpdwxmlsy@7L>}(Wcv{n2vmg>^FOGwFHVw*3e+6z(7P;#=e^RBLQ{JlSHn7$_eCKS^;|6VA5 zdCp+3$vV;UyhK*_6i;Y7#av?NGK}x3ils(qhxNaFO*e&ZOXn4NXr2X3bW`9xXrcPo zOfak+C2fStuidVjRB1isbh3B( z;N-(lA;p3|+zXl7^vjEAZw@=N+naji3MIqw4)dQA6cqWliOV?Src)W7f_Hf^d1bHd zu4x{e$<*%ueeNltEuS6j`fj^aNvn4-B7gX@G?>RH$tOwEqOx`WH)a6S@$oRLd{MZAhm$xwP8z6%KDWA7z9z@*=PH!I!vV8^8bXRy!CwblK0;G z==84bTlNAxM=o7E=6V|E)TTKmI@{Ce;I%CZb6*=S6Dlx}oC`Pk7p1*9UwPgKbLRA4 d+2j*=%n|eMq3uoc-h+R3eQUctf9t^?{s;c{mD~UT literal 0 HcmV?d00001 diff --git a/docs/examples/FP4_linear.png b/docs/examples/FP4_linear.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd4511adc96f04320bf9bddd2d67376877837ec GIT binary patch literal 54101 zcmeFaXIN8P7cQ!Z(t8s{z={wNK>_JtLlIDEks1&I>7jQJX$mR|Mi6O=6e$6uB$N=O zh|;BZ3`p-Kv=BIR!FKEZ?sx9_b?o`Qvu6*( zjq6u5_UxgA@7c4DhWY^bfTw%vK)pg5z%PNbksdgOkUof5-sbDO2~EWf7-vx#EW zelB*AH2T?qbD?x!kG(x}_RObsxjU+_J<{{iKc#;XdY?Y9B)df4FCaFQVE^;R)TXr9 zqOTu$3bB32n}uKT;hwz|l+?_>{y~2dsX(Z`ZO3@*KflSR41}Gf{PVZaL*>*I6)2W@ z9+m$XhnX3TJ@9+jzs5gKO%1n!(>wdCJdkz`eKW-5;~t{f-c7 z{hx2+KlkV%(u`6zg!Mm%3522l^KJi5`ai_?|1(Le9Ox`}a`Vt!>lX2St8`N3A1dK* z$>Cmc%8#qeq32Gk-j-IUu)#OujgDTM8`#N=xo?TAC;wJJ=pW_zv^F{9-pT^5IQguR zTQ2<-DEgFJ*oe<3PW%69Ihtqr4Q@#KSD>6%W?lxtxK$sq{;B=vMu3?(vz<1$VIgAh 
zt=x#_Po<9L*h^3;e#4IbkjI6-Z=C%PBdZU6h&lF^_}m&MPriV*$^2#_HQcUrofD?Dp_+v>D@6NPT!yTG2vCg}`{la|R*bQIFNqyZ zPn{wK9vIb9#A^5r1;L0}`6qzn{JVDqb`uDUwDJ|m^OO>_htU-NCKV`qpD~3>s^@ZN zy=VHCdM(zr<(4c7 z-m_%lh{;r3R4~_Ej_t)6Y!cUlxcTnZO6yex?AYqS!f0oR+wAs6heM}<-{p&=YmpT- z3QOI{CBCh$jgq^qBqn)&*_m%|+tGrv!cMFGtAWF}N3PZx9(8F-x_GX!+UxnogSvuo z(ul}<&HOtzzP|ssJy^Ux%6nANtUQWLrp&Aa{6VCEi;Sush_tkTICgtK?aJC<7zWn&o+7 zLU?WiGygd}p_%rc^k6+3-fcQnop#K3t4asyNC@ACmCkcyTNC;Ap=xBUhLP3UcYR$y87Mvi~8ZB+1lbRjUHn~gPy%JCg14dxpWL&Q1cU-jSs{{LM|^f z1;i@bY{$|0wb7>Di_4w6$hA;vk6lP9)Qyv9b;hNmAf-7a!f)`V)I;?BpN3!@Lm7%3 z>6!8y+}e*b*v?BR^zaJ0x)A2+9nkECcdvb4+)*{r*@*XW?75Xb44Ub~db@0m>AHAQCU``mWmsnSJ$l%@+b&{1T zQ0}u1W{no|j#s=+-oS_s4R*?fV~a}qs%;N&Uw+pjCA!g~W6^^N#VtYpqsKWiTw^wr$rSWN8N;gpV@xDa)&!+Q4p* zuCfCZSAZn0%=6NzP$77Au@N=0X^((1716cPsD&anNiz$g^|h!jw{`5A&bw-vY1W~` z!9#B!hR#*jSj2C=5-2R|EIRa|bbW?~&(lxl?l|%8?e@@!HW>VXRCwc5x_ezSo|n8Z zBC#Oc(_tA&BBskjlq|Xw(4WQ8>c=|Z5U~$G1I{$oSJ1u)GN&CzS7}ZVudBDqA9K+` zSNYb3_m-0U@VNz|cDY&f??#w+>bSn(lZaph1*5xNIexA z=xOUNgv{T0*ZkR7nbE4$#rF>Wd3uyzprL0)w|VH_URRTLsYD&S`#79b`ha9x$yMg5 z?zOLPaW#J9Lqi3Vujlti4GtwitP694EAeJool>1XXHGZjJ3NoGr<`1U{jO0;1T*KY ze|y0^dGTOCymBn2-)J>uejs3=^ao3#;~vhMna&jbJwWescJ$u36i9qwV`b9tKNLS8 z5VlYniWnI7n=5Rkt2ktx$?Lm5Rac3@yAd42lYlPVLsVOZt_{Z7QwW9Y5??0{b*=S` zo)mF6y6e)eUBT5TFHqIV@3zvJowAr7(KR?*GRuR;PP%2md27h3yrnDqD*3iI$5uv5 z`bUcojvbvsv!U^XRYFIv9ika7i(XPw0ywiSd~1Ov6~e0(S@w4gw6H!A=seWV4xjBX zhfdIxI_<2T;}7uQh&^P#;D0ay)yVQhf{HBW;Dyp*gaW(oGJ*@EE*{Lhy*zZJatA-O z*~yQde$wi#ZeVXls_}MJeCn%f~yu- z;oA55vI5fiF)50DY;_ic6P+#@o1JNJZITIh=8S3Oc4WLt~$K?y=fX{e86~;QF)jS3bFMzmr`Op1ug8vuk(9Tl8l#F%2)F;d}{#3oc ziPkw7#QM$){_o5B{$1v8=i%So1n12E4z9|-O8c+ULg(o(9R62nDXIUz5=Qtx-Ztwa z%O$p_zRuUx*~g}8x}t%ie!ceN`(e3@-czRbjVZiAfGoX*c*Bwq;0+0e*iY@&;f(mQ zp&t4xug*QA2=77S=4WnA%REbNQfpZ7vq7w&dq+pTpJfT_1;5)us63}ECOt|+)Jocy_Fq# z`~~Dm_G|qe$+_eN<}-SxracmAX7FhPIDCqJAY|otT$q~qprImGteB7^M-2}~SZKx2`oH~+2kMI0me>)jCoB5tO&vUt z;W+TsC+t7$-%F812V%^UJcU$*+8jp)UW*ZxfaqeCQ-E|{LwFezRBVM9fVbR?eTYnU z&^W@%fBPsDw$g|Z) 
z1q84`_j!f@p%#0ecoWQf3rK70njZLdG4RTqD^0wZ(e$v+9}YU`3y(i$+y`IZUoN2H zM7O21K+&mZV8(|8+K=v~;DGi$D(GLa{417!wdH>wBRV%nsO@(fWHZ``Ay>3_8gVrO zFFI#@E92zbFuQ$p@K0}9_p|VyQwG*86=RdC!v7B7>k;3Mou`hxKBLQuF58$y4tR_O zT@B-GkC5J48?N*v&2XJ9nf{bBTRhbveah@OK}E&KRPrRh0So$qi||AUK#siAAVOKK z>!MD%$Oc8cBA82EuL!b&U-OwKNHGa-|DL=4cS4zSIgj#6mF?zCQSOY#X`G z0S^Mfu$GOIp{XiU1!gE-t?&1=rfka(4~r$OvsJdnm+DT<&92oUDpK_26=cowYC>9{ z%M|+@fJ*qMs2tBd_&?VQB38_QJKTvP9~GS14@ zlQ#&a7B(3U^nnP@(4B;Zu{+qqdDkRfDHnnYy^j1!Zg_!Y#xteb0k1inF;mo~pWEU$0nCciw~3qBh*q04v%CHdA+QnVXo;$3f?E*ZCI=|j0z0cGRxMvF7{Pa6g%Oq-wZ-r*#(zQt>{w9Tz%F* zDWcQ#3XY+d;nS|6fycD9aRH)o0O%|IXtMZTib_BQf}N~~(oPUm4AfetGN=`$;pohq z3R;s$H3^z*&-M&GnAT9|5x2hFzCrE%XSHWy_Sn2r;{Dy3l*qqQZBEr^IruaVs)+k?4n=07OO?<uOsf)VGoM7f(J2MB&US6 z+wLP}_O`Z2ec7_}^`aox`5t|*pf=8}zkMJUBgAuPO#Gw)cPhFJB2fAuX-|7p(U=vW z+Ktoja8YeStx!fLD+Fha(o>l$-Boy*xzeQyfI_`~aH5bo48DQ5zk}+|T+#Qj8g5Of z#mSdg-5J=nAcL1j`*9%1y{q z@G|H!#bb14Ys4C@Rk32~hCXb+N(lII5MuGDYGWPyT{?p-q$cJ4Zv?{}O0kztmR(I-YOjm)g9jcGaj&+n3@aRzC^at0(kI2Tn_Q>Y-sip?!KBP#mf}Ld-!Qp1*8!jK&SW;c<;aQ)Rk~~MVm9Z{v)~`xX zGM5ud%Rb4WuB~g3qghDp$%@S zrBE3|Rtc@~M&q_t>5RqY*P=BJ@_!T^4Nt<3S8MTMlVgG!uTh@EEaT*55YL#azMP_c7+J>7rIWc;4_xFVsJ zUM)rQ=!;xvUw(pp(I|Ll`(esk2IwZ1PRIrOX`zxc2tJ!*;y!8G!I**mbZP3uY|&~q z!3~c;&O=%s4CLD!dQDJ2la~}v4x|dXj39O)$NE=;;ka#O!p7U}2;HDc)Oj=*&WcAH zoQ@ye466|Q2>b{>n4xpv!-_~A44tFlUadLE8vp6DPL|Tv(qeXS)7E&J=RS!(tGo~e zSm_4$9KGA1lZE?SCxXqcY-S@^IojK;GWf8XF1?KJk&o^()0p;8^Yh}8Ms{VP4sbn_ zWMucVsKAmr9IVf=m%$WYSR4df@ODLP7S4YlCa^un|CmQEh;y|`UA5g(29}}$LdAbu z$^z2tm1Y*L+C9H0LswmJI15+g*6JPHrY5f|T|Oj+-YwDFPY_Iq zu`)??RrW!2>C7;fz?&qc1BRB>q@bjM`KtMG1|6o5`1vK|=W0pei>1_XG;ZQ~&gx5r zxHs*v3X~E$;({$mWcP*n#UV`F3-qYkXaqwMLt!k}XyBoU!eJ3?Ei9HxhR65}RCV~q$ZCepRy@?ZO!H1r9zXBtwRj=so-Hnk5jwypKC z5MfN8*ekxhLK5lZQAvC^t!{^*TfcM6<>#_9K9hyZm4{t@ z?^}!4b#obUy^dSVIGQZrbQg~_6Vq1pWhd1BY@j-F2di-_lNzq0Ke0tveCx}*<=Zc? 
z(r@~Ht#cDGFyET!C4K*JY{$@2`XtO=hz$+Hx*n9Jh%~FG;Jx=uW>^HGzh@@cHnBme)&PqBH z?Jf6a^W$;8p*4?m^v$Y~L%kSk>#SVIwJxr+bg4chH`LfUweS|;ZWhxY=~Fg-mEcBD z!PRB(pkrMY7EHzjW;uMFhN}%%qfc)i^ZY(n##ZRNd=s3|Y1WvSc4RW*#cPmJ%pUED zln!ACf;nBBP=BLH;qSo+SKlIhV&48fq!h**qbMKvkq#bl`k6=bol2CZuaD!}X9lj| zknA$wE<5r}2%|0l_Ud$$Z}?Ys|7uf715+hx&A_t?CDY*qHNK(;pH<~N1H>!tWDlO{ z+Z>aL^4dIQcbc>~X2)FEXA`%wbrx>q=e4%RC4!hgcFocHZV z0>Tw?YrEmm3YBpG37e%?E<)MS$7F>zHb>h1?Ai_FF^&6%t>PP%!{SS-FIS+qb9}oV z;7Ao!U868A1Fq+!Mf*tkd-q3bqwU$ZTgj&t<|1(3kXH+fVxLLmN$&w(5uQ=ERJBrw`S;EY>PkOX;U*%**W>Md0cC{!hkR*)BR$ho!j@u*ne0>HH$KvX} zIeV0s3vX9M4o0>xJ*UPO6Yv$49ers90?L^JzU0N=gg~~l?yHW6r&bQ-S+!r$#)>`G z_Bh`6Y4D?hlys@%g3^rBM4^ez*jjb0o7aVu!PT&JVTAPv*&%dTCOcs5v0iq0_C>VS z*$k^tr5mwz!hFP`mEaEssGOUbc#e4MKJEzON@8h-J-NGgO@Ybh+FUB24u|%JQ zwtm2QgE8x)!KPd9NgCZi829E0aSk+ls+i5!h+_tV6_2IEN$g6hQrEdp8F+jd=3{5v z8dKAIr$VoVE&@j17?TmGEbmR!xqXN6d?^3VmW(YP-n`F;Op;NvT)yrT!iK#yKTO{? zbkf}N{Hd3c39i4O&0wS<*cPDonFND)_dfb zoLeB^r~DHJ>k-&E2Jr?L3*16{F@O^G7 zF3-6o?~<#D3bfXP+efqbPj0pNeYVFRieLKbNAFiUoPs3s`5xijc*vl??Wk3-dXtS% z%dewNP$9qv&Pd3&aNwmcM*=F@X?i{&b9%hvR;f`uIgP83Auu{`Xlu=Tk@`&<3usD6^Ddqu6{Ix!@^$M1?Cc|2r z1APU*SX8xDj5ADY&G#ONBfV?WrO~SQ!x#3CrcTb)E$K5(58c%!eQ=iP#BMxD5Z)HA zVu&jVf{`cDtC13ZS)M_#EjPz4G?j%aNEuU};`jvIQpr>{faV-Jj7{$2c@+{&y%Ms< zD5~Wn3aiLuD^%?{7`#|XD~K1zD?GV^16cTekd(#DS(;vScSMsTv5m%N!)VWO2@tcR zzBoj1@f`qX-ir(!CoQs=J~51(xcwf%5!l8+pI#B~9*mPf2;=3(w#geyKC6*}Ri0UG zX-lVT`S)JwuX%Kh0iMq5F~J$5^~P0ZuGn8&s(xbSx(=o+{ma?W6?`_y*S+s@LZLBH zP5Yq6vLW$-X7)N-(+!u9zixff;0<2N?F-0rA7zHOeA%9PdbI8${uC$r>C%Z5+*}GD zUHZ~Z9hB}YK}8b=4|NJBp=C$DGtuAj+w950hlUT?EJRZk>ZGxgFw3;Up$+cCNm3(l zp>W`J=r?F?k%sOhE%I;{q`0noxMCmq;tdIE)JE! 
z!gW`_5f+&A`-IGcNLgK;CAZh3riuwF9^%*5>o;)SDW&imahYMl;~&IC`Dw;OS5AA)36V>-0#2)0zF)%`;dpKc@*sA#J-mzfX4mIxV=#%t#&9&5G=qj!_=~03oG_jL{};h zmK&_JV(`&Q1{A{_)Z!^6otk-ODI=^A4C`-9?M%Up0J?9w&eD+H`Zu-+W%w$gH_Y-U zxd!b_F3FaaI#^kpjHHJTiF~=4u+q_CB#e{D%{tu0j(^akg4GZY>8HqOv^H*Teup60 z5md&AYfDqU%qzWCx_5n01r;G>Z+a^C(o3^lHk=b42C)AQ`Uj@OfC#q24#cdA1hiz7z&Ut`BfDlzA{Ao2 z9aObz9qAV_4&e!FxJ4pxg>ha=ty#PN%O3yQ?S_G^8A4B(dcw`v$|hg4h}+5K@g>NS z*^)WNo3GE%z!5$!7a_my{nb@mV_|P+;!H?UD3C=kcU~{kka<%t{^P6_QIqmeoPG6# zESB12wv;rm|8r%c*ZXB!L5&a?Zbj<}FS!gs=l%3qAk4JNHqDx#BKH;S(Yni*1?_^S zzP0r_oOVj@l-+Qxe|o3XEolS6PPc`EZSJ8YseTDzVe9tp8nYKDFw*k%^}EgdUg#?c ziv^#i_3f#%2Nxgj4{brw;BXn5^HUYRMhkLFe-}Bk*6etB9na>QtlRMFft5Tzu0bi z6BX2y;HOrSaxi#$eG^YtVuSOYnv<4v64vp`%zco&Kl{wUnn_p%i5s^ zG05c2FL8JH=q9-WwSq3(elPug`7U*G@tR{?gG#`GuY5PeQHGNcKVpGxnNhpcKcLkL>*H=q?|mCS zv@Y+HG^2m?s>P2s>&<(Oe%oV%K`tHcYmWZ<*%|0TUSV8g7WJc{Vfm=R<)xhO0ldAw z;!f*ccswyCr*V==f3~1On?zl$N5VZ5D&%$!QpY8fcH0ABBTKmkQ#yid%!ju@MRf3% zJ8pa)@o@ANR66;J?e5(7SeW*?k(FCr4;q*e7#@5KHN4hzbG}8tLfRV5TDBE1RcSf6WjIa` z->}so;S&1U%hqt#`UM$m1)THR-zo2>t1_m zw>}sByB^iry>Ev&$AAamJ0{yS$c3(7nwD{+hy@ln8DP)hdKYPH%qI*~1FA;8`z11VOH{PRT*_>DhaSHwEUJ2V)rDq1Se+k6q{4 zn*KDo`T~q&kjBBDJIRQY_n zV1x&F(A!JegmHTj?=$DoRvweVy@jFA?TJ!^iq+t->>}edWb)mY-(k$5gEF%vA!k;p zuk+w1Y98H7ScqCa!iV>Qn2p_=D#W}9c5`~3nBXTdz0Jg@JMXTBSF~i#a73h%*T!>} zEvbT!X+3#&$-7s*FZIJ_yY;5y^*duM8I2h^r3xzRxp7@+1XhF5Rhq9(h|0fzrTs}k z4NAJ&g>&qO41WFI#ijSR@+H5{c%MldOj0HK){f?(kY@5@{km+?`5?>JGURvIeuz-J zP{YUfTxv9Wz98F2*rsfKln@F-KhSpiD`uQ;sB426^d$v3HA)FgV~3v;M&H`oVWhS< zmnvWD+jZNOw5rPloo~Eo>VFR-(Z6zAaBC|wuhEK_w=tO!Qg^oGOeSun2j!JtHD{Y! 
z&6!mknjLvdUTV6q*N(37U1JfDIC)Q%vd>WMvu7P8_6f0#&F_pv^q$=WQNpv`qZ=T+ zY4I_2f@7Skjo!(eeCdUe&x)66^Eu+vrD&tW2g9BEbjS6vVsv@qJk&oyO7(`M@{J0V z+-l$X%ulsKSqx@eB1Sx+{2zF7s*uTfhWCYfIF+ukI9BneP+Lg%Rt>KV`fbnWr^*RW zf2uoM7j*J!&L>O$?esv3H~+yB8H9`_+lQBFE*wEG_ev%w-1hK$bx$PLx8}p9FMULO z{+3h40vFGK-i?urGLX|PL%jyEP%w4w0lT*2kwLJy4B2~Y4zopLg@Zpd1DQw4lZ7as z?5P%tT2zU2ysdF9U9ZTo=7?LMnnd@B(n;-Xy&TU|yq(?h!Iu_nFKd(Gag`uu-p}~* zo}acU4g4DJ>5+s05cdet9XAXBc@QZL!rJ~AMtEqAXF6|Dx|alwmm$5k=0uOUt>lv! zr$d)_Lh3=~(@-M5!8QVU5-s*b>AJzUJKog~Kk!^}-jXVW>3rnzn$D%;#XrfoRtS z?x;-P)EkN49wvTlFbr!Kf#Q!N5-%FQx$0d?jIFONpoTlfal5cY3q$QD>#?=SpPlX#xRs+zc3m{lnM4G=cDC%Nbj zyqaO6?RyaI03mp&=o1NIz#ubp^yWY-wW@K)a79Q$) zmCtR!!vee9amw+i5(aQXJO%H4;=*sCIZo ze+O~xG^t>a0%`!gJ24JzbbBdqVC56&7xw}6?S0h3@`0}#{Au#1y#_KHPwg(w@`7FD z>`(y7m|{?XU~q(zTJ;iu37U~fQSjh&?{Uy?6is>ZDX7=weuJDoLJj|-ZnFoZMx4P* zu@^zXNFV?{)>=@qQ0sFIRN2dj(;fyT%wl_V-lsq{@WD`_2(=P+NbK+6rv)qegG{?O zsLB5+BPo;t%=Du@Dm@u_5>@H|1h!k}dyX3L*))dfJor24Es!*89M$uYq@RTY%GYEFPM3+(9!2(2B(U7DoyK)AjickC zCYhjYXAmm=2R&jY_rX%W2Aa-+?_BTg12t)h84ppgvy7m-1k|0p33MbL1oGQ$(Arg( zzik;(AOrxx!RvGslz(lS4O0cmWaN-+V=d}{s0#`teWATmAl_vXOEkAkW1_NO&K4OMwO z@LLrqhi{WnK(WX`oygDf!LOacn#O}GSHL(r=+DEM2A0g4@Nd9F`ZjLObjIMLS~4hQ zp*#fC?1FmwWeQ3m@YTdwx?MW6+d6d|D$+T{<3&}0%4ItY^%NO`;r|uOzhe1UTmGLl zqGfE{Ajg(-Z`|;t5mY$G-+cMU-Y^}SZJ#ekpF#~|I>5T`dU&~>1y}&84;Q=Qn3(*^ zW}fd>A6lQ$6W-i>>ZWQsVRh}uYO>Jmu*LeRdmP*Gq?@7z)KpT z&Z>*VNIphbR{p)MheA+)*k5xoo(2-jWpg`HUi){;@TCM+5dXSFK@t9|79cfteI&;@ zVV@Apz$;6*Hp1n5W4g<6n(3S}{)V%ILrCA+ZI2P|P^hi0NnC!C^18o7| zs0n4uYaM&Dz3VebymtdZafPebI8jHOq(e+@+u5R&oyD~L){&!TSXS#yPmZPFDTR6? zn2W$|u=W%kezcy{Qk6)75pHL0^A_cV3xD29rqliDo3ECMacz2xI44u%WDU4_GlCX7k0$kT) zWMl=NYu|pxNWAe0(8MKS$iv(IML?!rzYsE55E?B${{fgx-s#&$1sbskd^=2gwc%LW z*22}OGd0Uk%;p%qrVHE4xJR_$^V(ua*=Vf{Jc*W3L|AoL1a zCuy~^^aQPD3OthsL-7VpaKjzFPXwj7!2hTu9)-$Yz9*H27Tsygb$%^o@X=HA_WB7K z#CKWFJ#l~)@xN27^+Hy(N?h16zJqtcThHm5Qh}=5B%ZS4P@z<$v>1_!luoATFtTnK zk+6t;72*Ja;SK+`)gL$FGjezArTjE7+B?hFWOT4%F^J>JbnrtaJk|p8UEwJB7?{+! 
z;1Awu3gBaaS^t;D^4Ez)Oe-|a3^#xVY_lBv$!joO0InZl0$;Rnn;d!ZQHf2 z3pyX~i06+^x!pSj06-jAHy-@iO}}yT5)Ub;*OpD4{4=Cd1(wxyUvL9l*`X-ruQdBP zU_dem_ARBo91_&I0|0yv-nMoXH26p(-e?X8-K0e8DnH-!TkeH2qx2QMiB3E`Y=v#n z2WYgK@e8kcetWvmY;$>F?3o5{3h6H4pm89w7%@Ba2cz70b?z$_X|w5C5cod&EcE6I z)I^MaJCON9#XKM@b2+LBHIP5;i0+| zS!2KlqO+XN=;o|eW zC=>ZE$vaK_htErh#>+!vB$el^{;&(z_KX#)(bd^8!VAEcu)O4em50ik#64djlOa=( zn5H6~3VpJ@u`wKWcB05lk!^T9V8@`^im3dSLH${_VX^gH%VZ8GIR=;A<>Uee`bE9Z zWzTQ6{_Yly3sCW;GoCXg_nJ@+wNr2R`14Rs_63C1JnUn0iWU`zB&;X|>``c)`8Q_1 zYosC-`!KYiK^h3O#Ljs}kIV!-wz~PS`~qzoj*9DW_kc&I;b)w~Aq%~R9aT*S9*Wjy zLO2(YVrSP&P%nQ?s_^-wZw{aSUTHu#%6aIpR80#%Wn7@s`XJ8SD9OL&kDYiN9UW zJuu?P?zK4)z>|&0!5sMG2sYxIp2|Ec7~a};_)O>zvtK6CulXHCg|pDbu9b>}wCfm(`(5V% zHfl@_6pr};BLrCfMeGp2a{a?C_E9-F8bChlHG5cMV4j7Jx7SjPUhq=fVf$j>+isuy za0z&&%n*uMh?_(nU8PPPxxFipoz{w$o+GqyE6r!w?~rERg=$1u02LR_S-fqC_0`9N zU@pEl)~Y_ad7Vl7_@skrFNH2Gy(7~sCZ^d)TUR~zJ!B7wM}b0JesS6PxB)js*2HBf zLnSddn8P^+0_aAL14WA; zCPv16Z=}ft0OEMAfkG`979uQ@!r|FzsC2Qx&T5AUtV5*tt<&%GZ=4G)>l0#FA-O$# z-Xw`ZGOgF__{@+9O@jQ0^Wh$e_aE6nlMlxOF@tYJo_3(*U=}eTA#>r~MFT!426srV z)>P@4)n2*P~kXR1X29NfbVJ9jlSKrtgOfIm;!_HP}1( zCJZP1mAJE{^MVjB$@AGW}1cjS6 z1Y+f90EV?RpntPr`o+&>INyP!*ZUPMrb^o3^lBN;ib0G93<%nVB4$0Ilkt}hD0J2MzP(%2 zGWYVhJ#G3g6$lsz$QWSPlHYb}x|qB8@|~rGh76=e%#R`Wsnh?D8ECci4ZA|1y~1vQr`jG}&lIV3)-nAd>ehw6yVRXM@x%>iQ*%_RgHp zsrjiO1E{uf$9^cM7fvJY!58o1HKq)`trgU0-jGE6vAI>?m@q~45`kw>NCu$L>tuG& zpaxj7I7Q_^KuOS`M1EY%NBGXIVf3Kmftxw=^JoPKfQ8J|)`D=@DYrfd6kt5dqVc{N z_ezByICegKNzHs+;M`X^tk5f8f@^QrTcQ4C?l+LIPC-o*{H`^(p(*@>SY>1=RiZY{ zw19Kvzd#27>FlKms%&yOa=`<^n2}XHxkh(k_mKE{1;UbJzt1|87fc|CSDX$xiHY%$ zoy=|je4j&9zgY>DPxF>fqC7kFcGS~(M?Dk%x!L-7wJd{NQPci6tw7y$nGX_qdFF68 zSEUlwOfL88VufP&L7Jik_01V_{TAK}P~L4*f7h)LV*DL|;^avrHA1Z+lwW|Hxua~h zeL>aDvo8Y`@i0>XUXO>+!BYA?)u5bz`<|8QmzSyPRDi@kgzlKYkUn838_KJHuwdf3=N(nf~Px zKqloylzM<=C4B-*)P*W4bUDD$(YH=L(qo;)eO8u~pYuKVaJ^9Y((c7`wvdj6fgY2c zt84&}K*ZM>+tQuzjC-mv1{gwpD8z#x#>l9!7`>nV*AW}bacE3w2N*daU^8U?Ev$2Z 
z5H)TM3%xR{;pI}7b=8GVjwwcbVjR-!v;-A^>1qnKzJR)nJyhK!tI+OO>d-h&KwoZ1 zPW&FGc}HPg3iCvVa~z)ox$97B3eTFLdmrs&9ErWK1oMx~){MP<==73|8nF@D+o^MVBR9&fl zZwa94)Dq7}b-I9pTXL%rtQ7*fay(?cE{3HUOhWDp#QVzr`qk+(M~k zs~jWgvu~SpDF5FiGoz#t!MXYVT>7VB8s@KtiBG(c>%{s;6MYs2K07;*326G=qVM5s zP#(zr&KbDeVVP~kZB7q> zU6J3-Uu*S8pJCh1i9jh0g)z`H|8O}z%6GYe8@;Bkgh?7z`90|;Gf)NdB`mG~_@20# zeH_r&bFqr)Uh~Ly=}>5C)W5@rf1Sav_@>=JTnRbaJ(hn{7*9SH466fIw8grEmbw-d zSclvZCr1Zf>Pt-e`pi8I$Alu>uCRI|>I?SUf0x$;6I+s(R5AB#pi^)bv+E zfq=|p1eTSLTN=bFdbXu|e|GIeLC)@M>fGBA_<6izXI$mCR6o(o)fX;(l1^_7Ty{8H z5@0mfq^KrA0}*ql`yo2`m*JlK*-&`Gv9Hn3UpU*qqt-l}Nq?^N0(&^UxD>h!I@uzj zDT6ZMF3e=#d375(b<1jm`uPk_Xz34jD;CIGP!Hah4h^iBM9&v#UVyKr0b93}9i&L% zBE=sHvUmEaXtjnChNE!%Vmdnb_EJdE0>pVvx)bNIA~mDw0r*P(YDy@f_LFHiod9o1 zy^*$*dSMKhfTHn#4-x<@3eV$&`k;&lC=cMMq9k)>yfgIZjUd-;#o!Gl?CUz@@u@@b ze@gEk@svcqX=S;j@c8*TLcvWpVs?7cCF%G*{%;^>fd`%j1wcAN@k4c1_%~RdSjzuy z6wyp8E+n)SytMm$=_{XCTi_%CI>!ke0n>RJE{$<8-IVUDGth1-CTc=_Ip&{yxpM{s zFImJFZ*+E?xVW_MR+g&JW99?YK#={%_t^ZK_C`URZ><`la>UMCgLYj;ZM|<>1L@Xbwnh z%+b6%mI_hlU!H<|SLY05SJcdD0wCse(T;WtL~AT(GFS&W+4rH=#@5W3GCW8BqK-Sq z(Az)e0Q>|y{O*eV_OAZexpn|r4#Zj9K+66X1$zYcCmC?Whl)!eD$?+F@+E{w=Fm>q z8ST2n_KTUx1SY>?;Q*Ca&1yf)=8Elo)#dNAoYDb2^gz#Ri z`k3VxL&t?yU7G&rbC1Ga+9#yruClO{*FDoiJ3-7@HYAhluParH0Ml_yI7&1-`Xj)H zvo#UKW=6Hj<=oW;1}vdN4`g|utL47m7Zi_s$h0ys{R|KZ5Si)RL3m((C4iEya%&(s zDdf8>L>WRMjo#lEFy=fIC$THwV}Km6tG(7}rCBphXdL(oXL#qto-)ot5+8`}=nX%0QuC z1*g6ZOJR>y2Gk8r-yRb1>-7OZKnK3JxY_}J=OjZ^k1aQ#iS^rBH>nfG(7&}W^MLBe zg({y7hbzcWgK%C!&vbsP`o56ER~a(0%RC<+{Z>4MmYL@PqnvX3CBye%8CMiTnnJ)| z`O$7&)@V0>nOW>{n@`_P5qhk%ZG&K8ph{8Eakbwqaxg9j_8!i`Bu^O?xaQsGlt<#)j%G_Uj0_Dcu%&ohvh=a+Czq^&Ee+Bv=szb^Q}pwFx)AkA699&*M4rE$uPL5bZy5YgZ-cL9&@!OO*r#SL;Xe zWh^P21aD{ZeDwT~x;Bs(Pt2U}pYxT}DiD11?3v^Mb}H%Wa4@(IS^w_#?)EAIAG}&z zDKQwhEC(V^0aq|kkwvx7V=S&NOZEvVYZ|vYKnP@0r_Kf!x?k3Qs4@MmrD$>F?ck>Q z#m)K3N+8o-M?Eu?c7;jE5zZBGN5y`xl4aU!LrF6BzDsW_zbo?@S{eq*jmWFZ5!?3H z*c5h~6Cx8fe?C%<4{ev$j(ZIs(~7HXNxrF1q%T>YRW-Pv4&7!7W=9R~=Gp6i2q^NQ 
z>oGvf*eLo}cm=oQ570*$d0KW^XX7r7ZJp?UE=&=1Fs=0GNf8<*_ws$0=;iD{QS6As zml;mbg?=N{es8E(<>q{df(;ol?G`S#LJjuP)LLBad(CCOjb2ZrvRa)7`s5C*$_`mr|zdB(M{%c}L#%#yMn5Ol9V&%OJ;BqHojR8`$W;oW(5Ig1UXH zQacm%@5=P{_qrTgZU_{>U$c{)yKS zBv*LPtEOdi$t-QcMq4VU?4h>%fE)9LX4upsuJ(_#i!YGK6t5Odl^haKZl%LVOp=~S z)LIW|{e^rdfOmKzNpk)Ia8pt&3)@s4DzZ_Ga0q}qNNzsPkcNsi_mjf33Z9Unsvh14 z8)HkwX%Xk5(-bk3bj^Y36;{wDGJL-{H@b>#ysufp?)hPOeU+TJeplR>X!BU#fE=kTBl1sl=JE*g#L zIQ(c}wQN-wBQ;xZutz;`S0#7zK~Z0ppc0SHH*N!zqwqwYOQqx)dwbkSYH9Y3SAuhh zMewBLbWSBBcuT>UN}tO@30FVIF~ zhaMr;snJr-v2Dsa8U@S6+bP}V*-jSmv2@=i4dTJAwn_m>Z;CF)Qtu*u4yYTgzWS%x z$13KE9gn|@L4Htwdo1uVnLT`&{cZObe1>Q1t@h*BMLfTU>A`er5lx;Qc@;^##Iz3U zIIsN|szv+3jlXkTW32EhaQ_Q00fQ&oPHorJJeTOcABn8yfag!)p=w+RDB8_N&mIuC zvWr~H0I{%&)NSC{X;W+o1nvOa={YA>J_gU#l0jUb)Yh2amIYMAs{c-TA!bzQ@mlC` zR}GWjT4e{u^x;Cih5UK%@ymR)w>(=SdqS2y^>6o)PNLsGso_B6;@A9g@mRrGa4|%Q zUFhNP36S%j<8y9S%EEzrJ|;7=2YRX`Ry;;f@ZjUmFTOA6RJr5lZGC&wsqb+obax1J zs0{*hSdaI!KmZ1-!KQK^eEL)=v9!o;ECU7Z1rG&%-;D=rTTbr&9ON57X{ji48pxA= z^#4EuKK*jyb|5Uh%S5S1Y4kLm-@sZX!>9k%)IP640>oXFH?p}BGpAh_D--Tyu8umOQ{ zLz|&c*L&cuZnp^>Ouhf%hj{_dHf-~fmCJqb*l_uFFjMWje}Og$Y-9-!aL>|bebWb!T)f4`6`Q1qpPRtT5>M6m6);m#uVY#z!5_Q(!>Uwk%T zjA@FBA8)t`K5#+%=1H)fG?2vr?zM?7NNH3- zN;A)O7@8Xgpe{S01#6GuKV3L#|Djp^Ogqqz;^Ux-J<}j$a&Qb+gn=$kg~3oP6kxCa z6~=#s@qeLa)ILM+sfX};$T*tW97D0akN8wSXsJA-%u1I{NCsJodrCm-E55oPfGH`I zFLqd{UztQ%;o3Vz8(%@41SMcHBw5pToeod(9ioAwZ%a%=It$jG*q{Gfd|~Wuk%s}7 zpo9KQ^mqlJX*NWCXa4{kW3$;GFatwXfW7E>r>r9ZS*k{0Y{FMB#Pfr${-+xHFa#C3 zq3gibvmyFKG`NLVV!*wB`79i=UZ48GXIt+B$mkcAf0S-^Ir;!pE&{hZN2s958@k=$ zz~Ww=5x8pthN$FLw*%G}OZTy&)$-Dt_aFw-GcjkIQi37y&W+*KOgo#x`CfyHvi%x( zs88^WeaYwk({;RC@V!6OU4g5!^d@wq-OaGHIC(e9_nU7%TW zXzQyeH4)K`eRR<~l&zo~65cAXPZp1Uyt@XlVGqhJZ=qr|{TKVfhg#>41U5%m^t(`A z0wbed?C2%!44Hm8yzSQXRjYp)OpGe{um{$7MHza|@{eigzp-j#Y@E&dT{V=D8$zVX zVAvcr(n3VP9rJWN73M^3HTX>cz*2z#QTMh*S3pTDl z5=dtwU>C2ze^FxIfVB!eQYtR|*Vsi2sz1e+*7DVZqqJh3@0Z;RB=C_*ZpB5Ni zN?^T>U(yA_f09&11B-|VMv6-w$HwwYN{3+HuT6;0YmP}5NJiH=%XWb93HE4YI@DoJ z=mTXyP%{^;c~2(SNj1^!^hIY0lwR 
zpvrL`Kp=n6r|&y~`W*T|VNTr&1fc)fwFQTXzWS2G$+e~lL&x8!4sj7sM~Dnd)eB4T zg64g1yGc7=eUn~g%I6CB=ibc;_Hk@?-LACi+t=7=J-eq-J{>t=?V!H>?L7eFl`v#O zxb{umhmSM9O)xY)!pMm-LAtrf4+$jgH-UujJX5oV>NNCXUqzf*7RtFpuYYs%$a)ktn^HBWVT1m9_S`$2oN%M8ua2zs z@fvQqu@JluH>+pf>m(x549hEFoFS{ z^UZ_{0wbY|YX!xnCHvBbLtF??BQKOt^yci@%00k0`g-ziqne5>RNpj&%3X~~zrgD6 zb?pPN`qM3_DEruCEQM$O3|ax7*K#ZU=_T^Y*Vj6g;wWz^IFv0fq*lCcqEFZnoIt*~ z>S@X>SZq5bJB(?zQ{V8Ivof3;eb|1XGI+lf7mIJ#XA#xREV`}tVx57&-PreJ)vhv$ z;In9XM`6ngIriSsXQ4oABLWgg$8T!+n;rX5{D}?dWyofsx_(EOd4bcg3;a--VeN9T zLH?AsobmpWntjgD9*1}lBgdsZ1TEK|N{igL-KLT|A2U|i)~VOoHn;Y-rCv<*?P--W zT@M78l~Rx&;WXePEbp}?*jzAeFX$s!BIJ`gUfFySa+v)9$&<2X{sqxjumNhu+xG|E z3JS!y7o=-~&)_iVqY$_t#Scb#eQy*4Jnd0XILtFRsOz1@&e0;$J73h_^vF}wxiMEg zeZVa1f$EkKlD~=_*f~(>{|FU3K5-D@tVrkWOBOlK2@yb|358OXk2XKv*#TH={h%(9 zp$z2$4KBcUax}!6AH?NSAnIm$j}V4NH~&#}*?6Y*y?x>;*H_y9;uKmaK7Yn{U@d>V ziwF=n>_23}5}?h6gAm(zaAn_Qp??&eJiWV5c?x&ce{!r-Raua0%Hk)IIzn zO)gwhKjl`y!$p(wIaebrWkOT`=^$@zv8H13i<4YI<6~^&G?em(@gE+&lRISKQ5Uy5skCyOkY)5XkMIjEV#Tvc$iS^5Z>6umxym6Fk}^-9nK4H4b8kr zw}y__B>;b+;_a(1NPAg9e_UeywD5H&td;2~{h`Bz zL4k*mRcdJP=~;i+@$RDFPfp!M0kJgTkIp+pi$d3(690o`(&4R_Azu||*~@-X?3E7BJa|Vu zQFJ}7yJ=4gTUIYiJr~$4wt{ycrBiID597SbuK2_Dx|!e~`%l9r6g_hS)(jwS53V?kD-DCVH&;N(7~t3JunC(_Bl`3BQ696VYRdNkhO zK05m4lai&*B4P{!n^W!|C&bJI6zMG6FGZq!9o&CL$$ym6HCXlgv+qF9 zuDJ6z%Z3yP`uRh6r9ce!{%3Fg2Op9|Z!poFw~3=+t!tQqrzF2(90AJug7=RneLwxf zbI^~HJsk8dA_!GBBv{q`>o00qGEe%z~r1H4KuhSrt!=mks4 z2LB8aRv#xgOatu|3AFuBdZEYi644S#k_Yrmi}fYql3hguDV`vtZe!r^e%$K=b!)qK z2ooznqYqTHPWBMy;hI|*-BF8E;EDA>_5XB#lEMxl#mxoVkAVyQ?hrCzmWWrBaFQOf zD@XpK4zj#px8oMcC-F|P2_42s_mFElp89h05N3pn00+(mONl?LA|r+_42XTF4R+f& zB@bjRZ-s|NWwSQy{IiJtVBu}=`M`yk1sm6r3*EdaMgR%+!hZ$(U%~$WQ)4GBFiU&W z9mo}U86&Z((|t$FN~6{S6L=4khU2NpyO1r(V)*yAL{nyCYX%*BiIw*8jAr*+Cj9`` zE#AJ#V&cmo%?`sa4Ua5eJu2u8S-offYW04*l7cGejfNU*U`mqx?Zkl>MJMRE+1aY> z7xFgE$GS~qiS`8MCnT!rmF!t+x{W@2T(l5>;BohBto^X%0kcd zNkt;eY|wdLFOK7Vgv|Ju>%X&VG5!ed=U=@a$pxjMcU>>oV2^HY4~5ia)kKq}V?+7Z 
zwZg7A-dhyjrd0LxmtgNUdn>r6xmkCC{>yT~}pYb|<3`4|ZzsV@w+b|JV(`yOfC zOkW3c;@roXlO%gPT^_VSt}OE^X%5Y1+gB+r`YaXXoqTFt;;vH?KPww^!4}Q8&tUbaizDnJV*4p zgV&B$=RJX$;x}VocWb1=bs7%mmvqglu)=4f?N%nAkh3I?%W&;UE$S#cw zTV77^kMsiVWj;4BM##kUaJe`mm*FJj1o*{sCxDj}3a6a(B|Nfy34 zebQX@g`#%9?VFtq){tX5LERm6+fI4`7cBd0oi9o)d!OcS>2tANnLKUS#qIDkOT)lW z!RZCFx5~gsco8K>qQk}&<>iUOBSvvXh?@5_M+JdLsZw+Kx9SyMi<(qVH-kOoh@NEk zN&awsBE3yB(=0kfbZ7n;6&SW^s{HneIC3&%JtbpAcPuO!kp+V{1jl?h$79*MB`368 z{ISnA)-t~F#>bPjZ>ekCDh=CO6C9pEj*4 zi)7s$4MXpp6IFEFjF>`p(X@@Bshm#aPoCQ}ZA68%H*x4eE12plfW77@j{O{k$+q7V z8ffy$+LR~|%?l_}@yyX3b`d)#yxs7~vH-v5DP8gY;78tPPpRJM(0|c82x4ykJIDUT z3qch9=K1CwxSm_V&b^J1@}+ksG}MuKOXJjUT%7i%n7go^H8c#BR_7k{mA2LK`ry1W z>d}f%RtHmihl|aSY?+ajIv>-OvSRxlH?;7oEu9Z>q>ni`h+F^mAX>4MUA*v$G|2|_ zJo2&^A2?L9ulun{#MnD#4&cg~OT|hZT3hm#;T$@4R&Qo_rI1xC2Jj99_nbAdwQHYg z?M84hl0+tPvW}AX7@T`TJls*x2Ri=e|EFKAte|Z@Ut%jT&vBaR8cB3+^ekpwWNE6T znqVSQiR!s#{ZN!`Y~NaFZu6b7x<_&%mxrQlkw$3`Rwl+9LUc2<8y0yLv7s~JaTEF7 z>*bwPr6WD)fF8>N?mQ#87ITA!WeIu}yT$~=lkD4Sm&f6f(8YcAZ!fN&BHFFWRjX~X zIAD_M;<$UvhZxVIitBT{W-xec!O9(z$t-1a!(B7}jY7jA&ozxL;`-3Sv{04L_j`r} z#^LuN^NeBocjj65BD_vyA%0T}6D5#h;z;7mP6k#aL-P;OnW7?IiCqTs5RG+0agpt- zNV_iX?41%Vt#|rHs-q%&QX_Q*Xfgw9k2ulYDv-&8 zrK}wP(vpOgjd<%{ulB?N_0yMACk(szY8@*f?(QjB=$v)xF#bm zzS=!ylGDftpYtDIR`Gy<8NHrk{l*mF|%~3btGvlZiZCu#@UKJ(Zyy?TZFsu6`r;_ z)J64}&D~3>`HExGs;jw-Vu%{qB4+H0@)^v-svPZrL1LUV8G|DU$OM5;3LvQJ+*)LZ zpW#xYFW8+6L4HZLuuE5tNF7Y#>otpQh;&2^;-%d4gB2gu0#SZS2&|x?^d8HJ{qW+A z@mcqYTW+G0BHyNWmr{(@B+9OL3@Ce(Hu{H(ltTK)n}Md!||T3`TIkz)6|w{S1=!BUe3zzI5Iok9xEsoGrjwA~~K;8)DtLHYA-f{v?74 zylwVwb{HhBDgJq)UfOG0TR64pTquJ+bsMr!dAex`PiaaaBiwNqd}o)^l$r`NQthb- z#YWR(UYnnUwwUbrDOg%aYkYc1`X)z_lQl_4>HK`;ZOk`kcuX@^H3OZM@##V{3VXD$ zv^>aVx!O)w4+0~#7o$W^R`;Ygzq+oEMNOETIO@JDf5W;qCwnu-$kl3m=-rX66URkj zYHCr79^D`ANNt(TS1me{TY=$sK@bloyO^_s%`a$|akNg%hfgX-W@jMDi2E`E>c z97DEom%HmVjTybeh?_=-PDZ;KBn6Mhhd|0j^95giM9xVDJou7UR4+M=HH6Z z)vV~wD5EbB-AxnoyCbUP{z|gd!+hJh|1RpfS&J_5{SDUw73QxW_33XUrmVISBpyq9 
zV3;wkosb;#VbpSUe#odp4#ZU66Lnx?=NyMIp?l{CE>mcokFvI?dQ~q?AK{df@1xJ1 z=KA)$oucVicsch8i52a3*t6j~E)mL@QA@lL--XYM|ieI%h7w zaN2O6`QA!}33+p1R&IF4hY#j&`xM(JKPa9=CbR^fH|Nau7gyIKtI2ICR(bVUY5a3m z&hBR;Vf`S^q!#rZROn7b}xiy(uLq3zaK2Sa|>0&#Tr z)U^N(L*nPker08N#!NY!SRe8X3vw{QZ5;(0pY<$+IYc;oHMTh8AyE0eZ&ebzM(sPBm@aw0!AHyg=Ze4c4Xk$J$7TqN!DShyZsot%F{8)ic?G}qqk>Y~z&;I_ngB^RP# zJ{wh1KGze?$eJu8P>pH+F=p@J3a9cI;pfdhTw^K8$>{V#NCm_4#hGdYF4jBj_ElqJ z(&!>Mu#{>v{$kz<6TjJma;qygc>^Wp*EQQ z){*RvWZ+aRK|$_^?>kI@?u`DlQfeO!8-cI*H*y9ZJPJq~zLu&In&tM)RjzIV`7#7^6gQ*w_r9x*K_nSnh?t10&z31B~xT z+sQsyGRbN(9%h8+HlM=83NuW*z9d9e4JkutRw>`V_zj;iulFGqs#51M&qY3g_GPK; zW$hyCzNpcqa>2PeRN;|(5j5fF9hH}YVLrxW`WED}Ip5l{k7v%m294j}IIC-&nS2*m zBEx-5c%3wQ(s&B@Sq9vcbIL3mn{+*eCRjOeu?6BirHW1;q#cn9M9-U!>#ucqA@o+7 z5;LwL;JqCpwPjU;Mfk{soaH!MLW|>2+z|hFaYJ}LXUO4?H^s%*(cVe_(!N?d}MfkLI~K`w%u zF1-f$X72&pp^tMg-l$SV*^L#snf$JcjRGgrOx3zXZA|5()EtVaMBB7|OYrTiQ-}vUG8}Tg0v+x8KS8Cz z1c%S2_*q4AC@Z%PUTIIxk9WKpWz;&Gn378slpu^u;E>`WC7e7+!Abui1?MEiv#CdE zPb3z$8x&L<@;s9fFt7Ia)JeEli%L4X^Z0VAP*?P3jT)#EB^)|juT!1w~#Eb!gTf-8u%H|DK4!n@LuWO(`bt3Ruk@= zd?GZIgffyrk0u#~Sd-}0boGMz&^4-Q7cmUvjK=;cfv$HIi=vxe7WC`z^1s@3y(X*s z*<7Pm=6ixVMOSOWCZ#O3aj`zpGW&=diz9T?XpJUUoiqm=L(o$iduR%k>_dGah*7)|Q^Zu~8;eLxCAFIgZ zIxBXuSTraDeZ0$-j1QdtnV4^cP`M%!90n5ai5Nc?g%_ikEQ6DNRUsS7jubC1dWm>& z>eekHEafiM19ct%%kwZjh00cJUI5K!%kJ40M{DXRjnVqS9?O|y{5%ufYdBmdGaXSI$l|cc z2yC=?lo?bHdJQ7`QhY3BNsq-XEi6j*Zi+AvLvJN!iVgf0!^I#_wTM1V$pWey7kOjx zq20uS-FPuZl&ImPt=u2aAI3Z!TbY@7i+HejnI2Bz2EIUtN2v|&Pp!8f2^0b4f_g?P z&&LY7J7|ixh3!Y^;nFlFskIF7hr{fN+{lD;Xqz=KJmH33YX}zzRMc=(XF=)eP@l`y zvLlePO=g}1J}=Jd@oR+M*(sJ$=vVyYff$@56hgtX0~%hRf$B_Nj9`FWELM8~L<#TI zQFdZz1%cfykjYaN6GNK;TlElo6$U&w4^#XRsLsU+ww3FRUL^2`y<)n&dI*bpSQDX- z#G(`mz4W{hm)s4aNQVP{tAn-@;Bxo`D+?i>#chjDq^zXr6VqR8Lm0{ z!+NL`F$6eiR9j}W&=1vwgiNRq3aU)Zytf3Q^yUksr1g`>olapMULSc#3QAd)#VITx zUwD-kT8%|LFC6ay!Ov2>l-l|cp4hd9Oi!0;9dF^@kK4w5BsKiHjyZ-Y9 z0IwwZpI0+6Og=tTIPNi@85QE`6cH}cG%e$t(FfUwW{|#@1UlKX!j1t;xo{d>m4DJZ z`9u^%OPFmc>DwG9Tb@*Y+eLt8udDEaMl3D1s?V=no 
zo}x$CfkabfVoo5nHtl4zs);8oEnd48n$L0W z#d%CCNuBqDtKisM+@|QDqOu0Kr6nc=NO2Rsck)!o1oDUaaz0qpnyIni@qbn{u)zQM zyGieYXxTmkZn}0G5sCpW$dIts4$KDWL}J`OA0N=+|J^H)&r^@;NQ2jZ_9WNZ+u7tX zAHO#>J&in$1C#0lN8bv`jyHrc&19C9w()A|x*;C;G5)hy|8##nWCLNK*DM-puyE#} z61KHav5;{0K7UR6a`ECVfyGZp@G$}NjS}ei(Z_cUaMBD)=5rL70GO=cGKBaM-Kc~J zC#MwAx{gFpl+vQ695yDMw(0iT^Qh5WbNyB+uyZ%F|7>hG8urvs*1?(gjPDXdntclUPOJ=_hwG%i zTw9_h*Z>OY2Tmb598{jhM&59tX%r3a)ZuCaXPQj8hlEC6GsUZwIaU{iwK5^h`4})8SFP0%6@{SGOpdX;Erugy>9KABxa@(DRhTVh4 zO6kL`D*FrsCcAHa#7zcKlwHCCR_n?5(>)&0dLZh9Gam4!$6r2ue_7}OIv2F2oX3!r zc=dRNM+w{2-RB=rsN+8hF1;dH$|-Q|Z^zs^*c>Es`9t7J9+ro-YoTz`aCk_eG4U<5 z!X%e-VhSVtYQkwXVx5DO_1C&#VBHz=qf|06(#Qnsg+U&(kvkah#$bJ}mz#qPVZ2&k z3waAw17|+hQSLS6z5+1>XDIJU7euX_ugA~FynG4_&zgA7<}nAY+P;O_Y|6Age%*4W z$5Lnxr4M2vo3(Ovjr>Ch)Lq+W#2=EoEogjN> z7qIB}>+ScE=T?h2p_;ZP%mkOkXy)^appUFiW85RH8%%q&a~zSqlyddyJ#S`c^N?e! zRWPk3jBkkcsx%OHzjHndZgNfx6kW=$<<^7x46}7S#`-m)mgW-8&mA77JZfJ{2~%uf zN5M=S#1jVuGT+wdNSLtI-0NmnYDCVcoxuup31_;%;}`>QUehDI6kU2q=1PSlo@Bzw zoh$r0O&6%H=?5WR48KBtbtcfJij9jN%csZx#_*df`c5W;_V<+G>uq#;M^VzwQeT|r zv;AvhBSVe;GYbya@>gVgrCkwvV{`TURLde|d!KzatvMJWQ{hb`q*$T+a`t8(c%=iJ z{z<2r)9$V6c{f?0`44ZQ*$5$`&e7k9I{2p+8Q`<2sCZ5uMI}T2c&B%6S*>HW0w%TJ zL>*$^*6~HGhAh7wd=Si`++nosBOt#9vomJVu)cs08q%7tphVU#sHOH*Yit*8XNe<6 zqg78fxp8ZRXNt=Vp`!aYo~vwM3Use4o4;|Lj|S#z7~R~ccBcJ8l=8~WoN(5UdAiLH zsIeiUB`Jnk58z+=bHhXooDo|!AtlD#`j%sLeQ9TEZ!e<*_%el) z{!Eg`_>YD0K1tnq1q`al&mfus+`{h~!uIkdDdxGBTfG$>?0E%a0wuV=YcqA$Bvj|` zQZgy1Di=RIMmou)9VL0fY2nKJ3-U5~I=tUTsdI18QocbM@I-WA1YlFAu$6Uvdy)(~w z8VfUPc)?1SYxLoHVozr(o@!VyPFf+gTNZ>YtnHbiyA>aA; zZvtr>^Uqw{dZLJVzO54H(XP^A2;clnL~&OA(I;4-H2P`)yLgxQ{{F?qmQ|1X77Ztb z>ts;65=a207fS;!M~oa;KW+Tl(>c6eP&R?VIfTvU>zew~ndpWbt=FDg#upjT_ji&dv4G7vKjx1e*pKna+ zd0#NRb*-LKS%;~%B0uKs=3{>)HW-}jH02OrY26r2XQGE^li_TBm>c?-K%h{qll=i~ z?{An8X9rZjeod#gZ9mtK((ZGaZ4?>)xTFdsHB+S6{Bcf9yCJ^g%K5?McoeMVd|UuN zCYGVfUSAB2tK@#fg+<-l=v}%38Dy%Xf2<{MPOT zX$MpkAoQ%j+!p_p<$OS%Jls}Cxm}3j@bE5V>$+sBERVka$kP2!LHj|Tga6Z5u0Yqk 
zqV*HPbtflc0&B8hmvO8L>&GOreOln_qljZeZ)nH$;dUzv(~qsSz6DcT*Y>FO-8{CY z_NYQx>pZAF_y|2Iz-RXKyYAU$XczcOz3*!8mIm5cHFpJ!ZS(8hw8~?Ld!;@oLOdz4 z&AS(Q{lR#1G#9TGuLEexiafTvb>z6YOCBZea8IC+(?0*)>bWT6R45u z=a^sHzhOOm+{WmeM`VNjDDIJv8ZS;-uvbq@s&{QFwYK0Il36W^bJZQ8N2IJwcMF|M z$?tl=Dn0P?`KK3|;j>2@>)voV>XY~A2Rh~nP~Tko6!>N=3{#H_B~r%6`<&i%eZ8x> zzOtF)WG|J#w}ogb#%-UTiyIh!mNp$uKeYh?xyyo%te%788c~~dQ4zK9weDgT1@GiS zTyO~PGnbRX_~obmQ~C7DI|ZfLKA{3(Tuk;&h0|1F*X+#rG0)*cp^y2@)C%^roFmN# zmh1FRY~Z!_`+Ov-j~i4n%Bah=Kc13ZGyS^M12tNcY9?;S|; zy2teT?h@i@yv*RRv&%ka~xKv?Lx^Ya)hrAe5cr$>6`b2q;;tsu&t zK^N5>Cb;nBey~zjYmQEHcs#cQ=)o3P0<#ekmA4CRN2_zqKg4S9u6CW>E=hiE=1mO> z!O4u^PmO9!T@mJzun{kYm(D`3MA1lcD^Z0+ksjal?ryN)6e&4_zL5Bs&8HGv#fwU3BXCekde2rd#C_KiZSZ@X6CP( z^8D%RMlJ84aW}a&=&&g7Lw$?akqIaLsPKN8xfc+j^-#&3+lmj+Zp{K&4||P zw|sOsoOyt)V)wlYgIAc`?x*4yj^+=9<-eQL=tzCMGz&⁡Q=EVWLisA9&nRu$=(| zX~53i#|O_mCvO5_=&#N|a0}E7%AUHjZol^=6_ukj9$7R}Ehla%(I1B)vT`1=&% ze0F76)a;40{@aimgpW1}Ed+fVt$8CX%4};V&xCoNo!>><*WbuanWI(nnsi%B%VT%h zaeMA;x~5((lDT(qO%I;hZFJMhTleH}T0-SDfQvnM67q*_9g+N5UkjjWfzDp+0Wc)u z{*8lN05Yj$U;mPVWCjTminqaFzW+?n+1eAk0LlAg&=!YC?H8A#%*irCLExZxlL zW}tyG3S9t;C1_c^5`B%hbipizVDc*F;q;XoawM1$fkV7Eeny!8uVGq4;=2FBwEu-^ zf%^YKOa27&{3kFiga{VEIRijZy`SgENm66MXQ8u21MAb$H_^Ix6U`08&=$s4Ht@HI z+6Z0ZU64Az>>QT_bE37so`4eqi^8BZMFFniOL9<7BTuZLW&ofM0(008hI&&x>;|0s zAU>24{9YOVb#v=3!aL_nN-yw?#Ky|!m8dDAb7^c+j@>~Au{YfdyouJV2x+dwqU^7q z=^;iY-2J#Ra2X`+R{q)-AR9UW4u7ST@+X+|e^FTi6uHG0K#{v8;y$Z_J7UT$sP_>Q zGC?>p;;qX+L)rgD|Nrkq|NmEuOf~m+ zJK7&5+yc(2RpQ@yn;QV`i3jnJ)M3~}A{G_toZm`=OrX@Cb8t7hIXBj*TLl6eib~W$ zw^t$fccFajNzm$s0uU9fyfa!7o!fM|fYl#{dC!r~gp*d+ohtMJmd41fQ6BP~jNsqJ zWWf6h0bsKCrhwxKsG9ocrmcVVn*TxFU6JMQNwDEj%ZO?R!vOnnjX>Nu>t%nts8LgK zoY!*N?c1Oc@(8d*$6ip;fja2X&>HZ9zgpMlSq`8AIVQ6(2M96%+GmvLfy&oCb<7Ey zex*HvbSMpLQOR*m2i2ysbvFg<(ldW8`;g~pY?K^r1@Ox$M9i?TGnXSm1a;0fz3V)N zY1Q+5z;#scpdr#2s>T~IftulOCB~RkWR?9VA=}v}JeptRH^0&?)UPC=1B9Y5=u01> zgl%ifiy|)yndIAjnQ!JAVCfP;s`B9Qp3WB5vW_hd$f7zHG6Q&ZE?klc=Hhr3v}V=w zgX=2K^ONhEeboljAS84 
z>nU@@mpRkz8l7jJ(V{oLogm!a&lfG2G^6@hj=cwi2e{p&j)n8zNkcMcdQw~xOD$vuuW7i8Hr;Cw z9eRqBR&{cnA^_2o1jLivg_zuGfE@X=u$6l26=~ys-if|vCR0z%-vuYWK}`Lb<*Fm* zCVKG~%QbhU`kG@h;JMZ@X80pgcmgMf4xPpD>}T{WEvA(PI}~iIHaX2KP0i0eeuJB? zVX6OvwihIq1htsrKs+A+Yf^akr+T#4L8UxT`ctt7U`__lb?yhv7Jqa5Dr($j&j8=5 z^PP|62(%Iw*l_BhHfkZtyfN991}mJqXe{F@pkxU#DHiY7++B!raA7-BDhzjA$0g+D zFgF}OXJ;&P6=vt^wJX0h*>8s{J!?YLWZR?rN%FJdk&n;DYg;qdUmAOq_3GGS&kr=M ze(1Y`I2u5;JUC$z@U4%IeOzEU&!UPEGd`R%wkf{PN=@3$c%8ShU858Lq@f7zne73R zP9r|;evkR<;5;91(Dz>{8@gdty3;p#DGzg1P=i&lw6VCQ6BQrbXE?`xGDn)-Rz#Lp zB7eWIVWgU`g7HT4WRcdklx;!7rtswjvqC4{d>gfD;}Vm_LF#JXh7w6y+c6S>5%MId z4Kd~I$j1EXf@9YLVT!4t``6SYeM&_dGuQL;#L>AYuN}D=qXoQraEq4XDX4I?M&bvm z%x4BbXsuhZvT*BbouqG&^`P=ogO8228k!S6YpC+wd9$S@ z_KW4)$mV7KV|bO0hoXvbPkQXXUVs0(uv+=ZCGR{;7-kK zM_kB_oW;{r?bfbh@uB8=+9mc8-aXUHOMu43v%kQFbvGdOw=Cj25ihr5JA4BteO}QF zS+&G#9=1IGD0|Cw-Gx(S9L9b>7BDwciI)w-QD+;s^n!gFdEOeWh$;2?3f#t)+VN2@ z&xc*uU)!qWS$-8kKE!E0iLST1;Am}_Va2gM*vDsPZi;#&TZR=Zd=PoSA_nzj0)qQf zrxro>F~W7|cZ*zJhq_O2qD_s1GC25?i&k^(G^wIx#E|5%c@8GNLAiiIIb59wvZ~HY zJ56NPkk3CXr83-B1GDlxla#_`hq0P9rNt)R=E^E%W;i8svc z6Ny=IPAS)?P)@@m>c*hrZl3?ID&>lndJO{g%Yk#v_Su}1J);yUT>dn zy^wStjbxTx3l)o?DLr2*JR!%m!}sk5)wau3W%c1~ab};FrUPe!9-}`SN$Op^KN89p zUXb?Tc)f>&a`s5|2dX`z^4m*UC#^G-wa8fw&Z`(*u=X>~nbT-=d~(%xg5yo$+@r?9 zl#5gCO+z6fHMh_k@y*?*l1U)0S^BTu^zwOUD`>S8H^06;+mzOKrro5`QZwX!>BUOE z{yx3yV9{gJ{ixDL5mSD_HT zSM;pE)`qTNF>Y#9T)I%&my$(*Um%0NI&zlv=tyX_QJ>j>w$`#fOJYTTB_sAm&bk!= z{zU)1b0SX$RZ3#>=h$Uua{=o%E6Z|XleR|cGEj&3`VXAgDK);UE}$j`b*@zU{=s?@ zW)_PHB<$a4zqGWK4d&G}JM!4aHP>Q&)&qf`6OucO0sRhStT#qPuzUeTXacWrQ?Fpy zxa&;3N*QHjN9*h9)8K?;T+&?`CgD(UPE4s9Vh!0{h*S`b!! 
zyYpG^dDBhqzOIHR{Np-n&EtztmOJY5V+nBQ4Xe;Oynfq8bD)*_6~^QJpw12kGs{4ewSK?&-f0dkh@s z4)dR7S-ia`RGx}bEQ4AnnC)p_VI$VYM`-#$CsB$Gl=q=}M zEPd(^YaARc%<-ytnAPF~*MlEIt*-MR6VhvWxCkeq&?w|*XaqEb9Y9u~gG-T)Ugs>lV60BuIMflL1jxpv>#|p5Gigion~iu()1LLVWISxF zCan*j?7&KuaRoGWWi|J$Gt8?H5tkKO^VnC*7Sb}gf|{U`weX98LYZw^kD(${5%*=# z)}lFPSFhC=<)BZn6uyNByO_mU+3XL~y6Z0>CMF6Q8Sful!jv7r5P>j= zW6W7gS;QC@cFGGq!uAv$O4>47AgTbl`2j@~m7E%k)qX3{U~e;pSm;o06X~^N3)3urOoH#YrV=`9Ulk( z=qe#Qb4x0+u~KhHqp~LbSoT6T)xsZK-(}GEI5B8>&&V#IUNI$ghgJfuz^pe;bhr#C z;NHfMK0G<1Q^{n`sdXrJEkqUOBcWR^BHK_y>LWLs5$EullmmaD^ONevCBqye?I&Av zpLO2&!4@)AIVGJ|VcB0-yKr#+einyqD@(@kr;S<>a*`f5ztPA*-c;}hTNJ#N>@m=x zl!v_?VFrW9s>}%M(Oj9A4Im zM}Lu|oC>y9e?fP90(1fqsP5bzuxzmL?$M_cmr7Ke`L=rjS!IpqP3o)<9=A~s+Lyjh zo&ckGQQ${YgZRt=88Y`Y=!M+d`r?ZG^0~xdY+k^a4zqZ+%|*01Vmljj;!5o8ttaF~ zM_ZSVEn&DWoZ1mErfwvajb`jpHBaf^{bcyWx}^MOk@b+urBN%oO9$=B+l`mX*Y6dT zHH2T%d1sw=(ueijKizk)ZF3p5vZ( z`#F*4@rJqI5}hrG-uUQlrz5MdfgjQdi7AyIA$%kv?+JWM{-`szK5MRcg2r|O7aq#( z;INI;aKWfpOy%Vrd66-R+3Io%BKH-FWr5pf_6S}7P$20Wb&X16t8ir`x_!6DRG$ys z8{$&&gX}|7bO(Lig*_XP?f_7JC$)YqL`~nEo7>X&D+gb zy9ONFK~c*XW?VtUH=soF$b5q+bx?)Hk$cI#S2)d6wOw{7d*Cel>^#K&ruw6js%_x zc`Q6u2q%49N&K4lZ><003r8R>gws(_e*S=k77~3e@HDvf_&OKy z9}gx&eIL#8%Q`hs;+^^;LC+2!Y2wMcrUU2-#`UuOd~t=8mu)f;+{;-ypYQFAlxJJi zj)=vEVx0cc01k*oQ(VQ!nsj&1L6>+=>uG1PBJg9q_jjU4f6SyfFt%Vrs8($>*bpC{ zSpmMdNp}`ctkeEp_P4|c0{$g86Ed%S-@wxk{K*Cd$wdm{Re<{YRkF^p+Rc~v+S_J+ zD;I#}8#5{It+kWqx$49%+f3KijBP)XJCmb@HMH@HFu8D>S9=3f=xUo+Lf3x8xfrO? 
z5@T2?9ZuR-8zWktAJgyO3g|gdM`eWpR9*W-=X@RWJT(jXNilz-(BY#Buc6Z!VafOw zXyVMw+~W0d6(}+*U@;s1$T3p8&yTlNrw06d}jsFC}%MLyR+6 zqg=#v4ooT$FsTkknh=N%-X};2K)n<(?7h=?r{p2W#2UW%xOu!EG!a5GcqpBS{^SXN zKO8Kjk*7TA(K^r-H^K$~Cxro3OU~2au>U+u>_?>cSULg&cx%1j2y&Tk37=wPg<=lG zkHqFj%ptAE3nvs4qXvJ>zmrVx*AZra1YHW3jK+^2Ce#IEsak>wSH)r?Ay1ST@c zxY_=76579Lw0J#gCbQ^iV&n^j0=*yZNq&1zjm}n{WOqq~XRf+WYvc6;%eje0hbJuF zbWBO_Y@%j?f1@hczH4K-*YC~-!im}e(uQEmF6w>?>)d>Gp|^meVH&YA*T;;uG^h`> zrH=u)Yn2-4V*fs7?XPI;AbN4Ti|X{#{<08jBISSTx)|EQ@3#s#R8rl|lU^Fd{*M<*(^h%@=Yf3`P3 zHyfubK-Cl=JUjEAS9ux3xpjK}W1B33^IAv?rPewn@|o`=vg$_VyQq|v!*1Nreb=4!!2qcvD`kq?nEsaFvv z8`{{&<{|0P+=D3>pcmoqBSAoLip7fu?XA{Deq(gjLJyRxL1(Pgh=g(}I3Ui01A^dQ zXk6aE91wsJM8*w1W2utBs|9$fBx$^f4ioM`nEX)RU(?c`#wzGOy4(DLY&i0x{X8*& zZsmR9Q^_3A)tLMz!{wY^OPQlT77pX`|Kp}6oa6@;Ia?v)A6wtP`q-E9%uj0l~|c=kazwb_OV$6yCWnOaXB1zYKT;mgfi0 z^*h}$8}dOiWWNsc4<6|sd`S5-&_<49~(pgSV<+a3AS2gLq7!3S8E zA3V$NPhca4j_KlW$Mok9XhFTtAs2s%`R5ORC%bYSXw#1=hCrMC;PC#j$RNO7uTl&6 z&Fuc^1J3^++1A)oT*l$P` zIGryNUk9w=b%g_xb!p2JRxqO#V-4Jp?~SKx7YhEL_OARN>V1v7Qe7poOv}j9$&#q! 
znk3txMV7)jrAbCgn92+>HI8M7sUvO($FT?LJ@eyyzpu~h^L(HAKF|C8e4gh$&hE9gl#-PQ(z2I=Kx(g9=QEOUQ~zPo zz~b{|a$Czf8!AzBj7r#cdZ701wYA<|nAy~bkD}H1Uxq8+4Kyh<%87-aM9f~i`&DwN z^l4F$3Tj}OM~z_7i?G&BjGyIDq0qoxbF54O+O4?%2tyRrQT#rX)ce_keZl}Cucrhd!6;vymLc$&Ob#1zj>QO0}3yGY$Ae<@DALybwtxYdr z_y(9x`FW-_jfGu-`0>e)4e}-H5+V9lmywtK1IklOy^=OO&T0}r{GGmN!=HPQQCA=u zPhBgXE$F2E*gDQsW*PKYzU5qw87rfW-b9g&mL8rI0A|uTW`tODpp2_G=x|NCG^`4B zV__7}>7d}TaU2cei3AJpeVXfUxO4=K;`(2Q;47<@KiCjE+J8VXK-{?9h8SiREGDcAGto$)}uM3y8~#s#h<>u zMpfVFMa_(ow`8`lJ0VXQzc9)tjhZ#N!OAVswSUMNi?M=xzX7;L>O$qWpX1lqcy_xC#l zVuv?FThCKGM}Sc?{+XnryU>2jbRnOk?Nr!8%i61`)b2JX)H@ov9Z=e&y^p=J;+0L0 zBg2Kg71Du9r~1xBZlI@bM?2=4a>t8816=8($c7JugcetAP<(fLT37=_E70!xmH-d& zhgandq0#qa{hqQnMz4~@b(c^m9Vjl9MQ(sb7imdEciE>oOv?#SC!XN9kuP=(-cHn2 zk39CCi4UC^dp-VaBNmiCZ|v51RVZhu@79ejak)`lI$vg0mDQ)DYjERYMrf@`f82Ri zN(xh~;e~p*DWk=1qO^XifgOR5J50x)N*?3_O%qjJ`!}Wa+U0!CZEDc8iP&)uA8SXa zw>%*yx2GFPBKvcwOG$U)hv+XAY)MGd9JbXIZ8a~yUs-*1!Y04P^crYdUhKmZUfskP z^)If1`(1atqexd`TPVZe6$GvxcS=$&9PwQI$;smV-$N$EVI)_b6{{)A3-#g-b>yho z>LWpw6s|1ZJ|!l@aXRND)uRquv?gHEq1B0J4iE^Gf|`NgZ>r#NLc-PW_hTQy#>L2L zz(u)>jC+DR7XFS$UOq|yug7s*x`GJQz&77@RYIb)yO0L(`gKTO)g(p(r%T2=(l?8L z{+$OY`x=ULFN37RP@;d_K87~fA6C&^X=Rz$XTyF_=x2@_-!rom;wM2AtVHPFq=W7y57;hd8$0bP_12(WDMG+J7*(>zGd=pCYwKf z)-P7I-GNk^eV@EP86Ddg_$Eyr{AXPQY})}jmJO*b)u*moo)Sf1?5tKF@$+|FGwU;` z{G>K43>;0+%0KNvLR|4$A zaSi@)oru8XEkWE!g0eL0*_q6R*;grkU;y#NV$sp2-_DBPp6-Zk^e6bVf6e$3 zko_|V$arDCgcbjSm=b&~>a6BPOsx+Y9Sc@Dw(onGui5-iW>o>-Ig8jk;u(AA4@W$~Dj9(Bs=Hwn%}G{XwS#lzpe8{|m*>+*tqs literal 0 HcmV?d00001 diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index 788d6c37ae..a8ebd770c6 100644 --- a/docs/examples/fp8_primer.ipynb +++ b/docs/examples/fp8_primer.ipynb @@ -5,9 +5,9 @@ "id": "7b3e6954", "metadata": {}, "source": [ - "# Using FP8 with Transformer Engine\n", + "# Using FP8 and FP4 with Transformer Engine\n", "\n", - "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix 
multiplies and convolutions. In this example we will introduce the FP8 datatype and show how to use it with Transformer Engine.\n", + "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. Blackwell added support for NVFP4 and MXFP8 datatypes. In this example we will introduce these low precision datatypes and show how to use them with Transformer Engine.\n", "\n", "## Introduction to FP8\n", "\n", @@ -100,19 +100,66 @@ "" ] }, + { + "cell_type": "markdown", + "id": "fd7b4f37-50a2-4d41-9067-cf0c471cb2d7", + "metadata": {}, + "source": [ + "## Beyond FP8 - training with NVFP4\n", + "\n", + "In addition to MXFP8, NVIDIA Blackwell introduced support for an even smaller, 4-bit format called NVFP4. The values are represented there in E2M1 format, able to represent values of magnitude up to +/-6.\n", + "\n", + "
\n", + "\n", + "
Figure 8: FP4 E2M1 format can represent values between +/-6.
\n", + "
\n", + "\n", + "### NVFP4 Format\n", + "\n", + "NVFP4 format is similar to MXFP8 - it also uses granular scaling to preserve the dynamic range. The differences are:\n", + "\n", + " - Granularity of the scaling factors: in NVFP4 format a single scaling factor is used per block of 16 elements, whereas MXFP8 uses 1 scaling factor per block of 32 elements\n", + " - Datatype of the scaling factors: NVFP4 uses FP8 E4M3 as the scaling factor per block, whereas MXFP8 uses E8M0 as the scaling factor datatype. Choice of E4M3 for the scaling factor enables preservation of more information about mantissa, but does not enable the full dynamic range of FP32. Therefore, NVFP4 uses an additional single per-tensor FP32 scaling factor to avoid overflows.\n", + "\n", + "In the NVFP4 training recipe for weight tensors we use a different variant of the NVFP4 quantization, where a single scaling factor is shared by a 2D block of 16x16 elements. This is similar to the weight quantization scheme employed in [DeepSeek-v3 training](https://arxiv.org/abs/2412.19437v1), but with a much finer granularity.\n", + "\n", + "### NVFP4 training recipe\n", + "\n", + "The NVFP4 training recipe implemented in Transformer Engine is described in [Pretraining Large Language Models with NVFP4](https://arxiv.org/abs/2509.25149v1) paper. The main elements of the recipe are:\n", + "\n", + " - Stochastic Rounding. When quantizing gradients to NVFP4, we use stochastic rounding to avoid the bias introduced by quantization. With stochastic rounding values are rounded probabilistically to one of their two nearest representable numbers, with probabilities inversely\n", + "proportional to their distances.\n", + " - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. 
While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n", + " - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n", + "disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n", + " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `fp8_autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n", + "\n", + "The full linear layer utilizing NVFP4 is presented in Figure 9.\n", + "\n", + "
\n", + "\n", + "
Figure 9: Linear layer utilizing NVFP4
\n", + "
" + ] + }, { "cell_type": "markdown", "id": "cf5e0b0d", "metadata": {}, "source": [ - "## Using FP8 with Transformer Engine\n", + "## Using FP8 and FP4 with Transformer Engine\n", "\n", - "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n", + "Transformer Engine library provides tools enabling easy to use training with FP8 and FP4 datatypes using different strategies.\n", "\n", "### FP8 recipe\n", "\n", - "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", - "Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training." + "Transformer Engine defines a range of different low precision recipes to choose from in the `transformer_engine.common.recipe` module.\n", + "\n", + " - The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", + " - [Float8CurrentScaling](../api/common.rst#transformer_engine.common.recipe.Float8CurrentScaling) recipe enables current per-tensor scaling with FP8.\n", + " - [Float8BlockScaling](../api/common.rst#transformer_engine.common.recipe.Float8BlockScaling) recipe enables block scaling with FP8 as described in [DeepSeek-v3 paper](https://arxiv.org/abs/2412.19437v1).\n", + " - [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) recipe enables MXFP8 training.\n", + " - [NVFP4BlockScaling](../api/common.rst#transformer_engine.common.recipe.NVFP4BlockScaling) recipe 
enables NVFP4 training." ] }, { @@ -122,12 +169,13 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n", + "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "mxfp8_format = Format.E4M3 # E4M3 used everywhere\n", - "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)" + "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)\n", + "nvfp4_recipe = NVFP4BlockScaling()" ] }, { @@ -135,7 +183,7 @@ "id": "f9591eb5", "metadata": {}, "source": [ - "This recipe is then used to configure the FP8 training." + "This recipe is then used to configure the low precision training." ] }, { @@ -235,13 +283,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2276, 0.2627, 0.3001, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3725, 0.1717, ..., 0.0901, 0.0522, -0.3472],\n", - " [ 0.4526, 0.3482, 0.5976, ..., -0.0687, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6061, 0.0385, ..., -0.2875, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2946, 0.2751, ..., -0.2284, 0.0517, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9172, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -263,13 +311,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.0233, 0.2498, 0.1131],\n", - " [-0.0767, -0.3778, 0.1862, 
..., 0.0858, 0.0676, -0.3369],\n", - " [ 0.4615, 0.3593, 0.5813, ..., -0.0779, -0.0349, 0.1422],\n", + "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.1134, -0.3661, 0.1650],\n", + " [-0.0767, -0.3778, 0.1862, ..., -0.1370, -0.8448, -0.1770],\n", + " [ 0.4615, 0.3593, 0.5813, ..., 0.1696, -0.8826, -0.1826],\n", " ...,\n", - " [ 0.1914, 0.6038, 0.0382, ..., -0.2847, -0.0991, -0.0423],\n", - " [ 0.0864, 0.2895, 0.2719, ..., -0.2388, 0.0772, -0.1541],\n", - " [ 0.2019, 0.2275, 0.9027, ..., 0.1022, 0.1300, 0.1444]],\n", + " [ 0.1914, 0.6038, 0.0382, ..., 0.4049, -0.4729, 0.0118],\n", + " [ 0.0864, 0.2895, 0.2719, ..., -0.3337, -0.4922, 0.1240],\n", + " [ 0.2019, 0.2275, 0.9027, ..., 0.0706, -0.5481, 0.1356]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -300,13 +348,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3724, 0.1717, ..., 0.0901, 0.0522, -0.3470],\n", - " [ 0.4526, 0.3479, 0.5976, ..., -0.0686, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6062, 0.0385, ..., -0.2876, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2947, 0.2750, ..., -0.2284, 0.0516, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9170, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)\n" ] } @@ -339,19 +387,14 @@ { "data": { "text/plain": [ - "tensor([[ 4.9591e-05, -1.9073e-04, 9.5367e-05, ..., -3.8147e-06,\n", - " 4.1962e-05, 2.2888e-05],\n", - " [ 2.2888e-05, -3.4332e-05, 2.2888e-05, ..., 2.6703e-05,\n", - " 5.3406e-05, -1.4114e-04],\n", - " [-3.8147e-05, 2.6703e-04, 
-3.8147e-06, ..., -5.7220e-05,\n", - " 4.1962e-05, -1.9073e-05],\n", + "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", - " [ 1.1444e-05, -7.2479e-05, -3.8147e-06, ..., 5.3406e-05,\n", - " -1.5259e-05, 2.2888e-05],\n", - " [ 4.9591e-05, -9.5367e-05, 6.8665e-05, ..., -1.5259e-05,\n", - " 7.6294e-05, 4.5776e-05],\n", - " [-1.5259e-05, -7.6294e-06, 1.8692e-04, ..., -3.0518e-05,\n", - " -4.5776e-05, 7.6294e-06]], device='cuda:0', grad_fn=)" + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0',\n", + " grad_fn=)" ] }, "execution_count": 7, @@ -370,6 +413,53 @@ "source": [ "The differences in result coming from FP8 execution do not matter during the training process, but it is good to understand them, e.g. during debugging the model." ] + }, + { + "cell_type": "markdown", + "id": "d45e8b6c-803b-4a4f-8835-c19b0a94bc6a", + "metadata": {}, + "source": [ + "### Using multiple recipes in the same training run\n", + "\n", + "Sometimes it is desirable to use multiple recipes in the same training run. An example of this is the NVFP4 training, where a few layers at the end of the training should be run in higher precision. This can be achieved by using multiple autocasts, either completely separately or in a nested way (this could be useful when e.g. we want to have a configurable overarching recipe but still hardcode a different recipe for some pieces of the network)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c663f694-41d6-47c0-a397-5fc56e692542", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0547, 0.0039, -0.0664, ..., -0.2061, 0.2344, -0.3223],\n", + " [ 0.0131, -0.1436, 0.0168, ..., -0.4258, 0.1562, -0.0371],\n", + " [ 0.1074, -0.2773, 0.0576, ..., -0.2070, 0.0640, -0.1611],\n", + " ...,\n", + " [ 0.0825, -0.0630, 0.0571, ..., -0.3711, 0.1562, -0.4062],\n", + " [-0.1729, -0.1138, -0.0620, ..., -0.4238, 0.0703, -0.2070],\n", + " [-0.0908, -0.2148, 0.2676, ..., -0.4551, 0.1836, -0.4551]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=<_LinearBackward>)\n" + ] + } + ], + "source": [ + "my_linear1 = te.Linear(768, 768).bfloat16() # The first linear - we want to run it in FP4\n", + "my_linear2 = te.Linear(768, 768).bfloat16() # The second linear - we want to run it in MXFP8\n", + "\n", + "inp = inp.bfloat16()\n", + "\n", + "with te.fp8_autocast(fp8_recipe=nvfp4_recipe):\n", + " y = my_linear1(inp)\n", + " with te.fp8_autocast(fp8_recipe=mxfp8_recipe):\n", + " out = my_linear2(y)\n", + "\n", + "print(out)\n", + "\n", + "out.mean().backward()" + ] } ], "metadata": { From 0db0f4d2d7ca7ae6e761294aedc74b6e30a8aaf4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 6 Oct 2025 12:05:31 -0400 Subject: [PATCH 63/78] [JAX] Fix for GEMM + fuse bias + AllReduce (#2230) * not fuse bias for output all reduction case + unit tests Signed-off-by: Phuong Nguyen * norm to reduce dgamma along tpsp as well Signed-off-by: Phuong Nguyen * clean up tests Signed-off-by: Phuong Nguyen * fix test_distributed_layernorm byte counts Signed-off-by: Phuong Nguyen * increase tols for jax_gemm Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/jax/distributed_test_base.py | 17 +- tests/jax/test_distributed_dense.py | 253 ++++++++++++++++++ tests/jax/test_distributed_layernorm.py | 19 +- tests/jax/test_distributed_layernorm_mlp.py | 
51 +++- transformer_engine/jax/cpp_extensions/gemm.py | 84 +++--- .../jax/cpp_extensions/normalization.py | 6 +- .../jax/csrc/extensions/gemm.cpp | 16 +- transformer_engine/jax/sharding.py | 15 ++ 8 files changed, 382 insertions(+), 79 deletions(-) create mode 100644 tests/jax/test_distributed_dense.py diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 7c08539c36..4693086b83 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -17,14 +17,6 @@ def generate_configs(): configs = [] - if is_devices_enough(2): - configs.append( - pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") - ) - configs.append( - pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2") - ) - if is_devices_enough(4): configs.append( pytest.param( @@ -32,10 +24,17 @@ def generate_configs(): (2, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp"), - id=f"n4_dp2_tp2", + id="n4_dp2_tp2", ) ) + if is_devices_enough(2): + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"), + ) return configs diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py new file mode 100644 index 0000000000..9541ccfcbc --- /dev/null +++ b/tests/jax/test_distributed_dense.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from functools import partial + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.dense import dense + + +DTYPES = [jnp.bfloat16] + +GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in] + +WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out] + + +def _generate_inputs(input_shape, weight_shape, dtype): + """Generate test inputs for GEMM operations""" + _, _, hidden_in = input_shape + hidden_in_w, hidden_out = weight_shape + assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}" + + bias_shape = (hidden_out,) + + # Generate random inputs + x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype) + weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w) + bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out) + + return x, weight, bias + + +def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"): + """Get sharding patterns for GEMM inputs and outputs""" + + dp_axis = mesh_resource.dp_resource + tp_axis = mesh_resource.tpsp_resource + + if partition_layout == "colwise": + x_spec = PartitionSpec(dp_axis, None, None) + weight_spec = PartitionSpec(None, tp_axis) + bias_spec = PartitionSpec(tp_axis) + output_spec = PartitionSpec(dp_axis, None, tp_axis) + elif partition_layout == "rowwise": + x_spec = PartitionSpec(dp_axis, None, tp_axis) + weight_spec = PartitionSpec(tp_axis, None) + bias_spec = PartitionSpec(None) + output_spec = PartitionSpec(dp_axis, None, None) + else: + raise ValueError(f"Invalid partition: {partition_layout}") + + x_sharding = NamedSharding(mesh, 
x_spec) + weight_sharding = NamedSharding(mesh, weight_spec) + bias_sharding = NamedSharding(mesh, bias_spec) + output_sharding = NamedSharding(mesh, output_spec) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding")) +def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + fuse_bias=True, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +# TODO(Phuong): +# 1. Add supported recipes after FP4 is added +# 2. Add communication type/byte checks +class TestDistributedDense: + """Test distributed GEMM without collective operations vs JAX dot""" + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_distributed_gemm( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM against JAX dot with bf16 dtype""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + # Shard inputs + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension + + 
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + # TE GEMM result + te_result = _jitted_gemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=contracting_dims, + output_sharding=output_sharding, + ) + + # JAX dot reference result + jax_result = ( + jax.lax.dot_general( + x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ())) + ) + + bias_sharded + ) + + assert te_result.sharding == jax_result.sharding + # Ensure computation is complete + jax.block_until_ready(te_result) + jax.block_until_ready(jax_result) + + # Gather results for comparison + gathered_te = jax.lax.with_sharding_constraint( + te_result, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax = jax.lax.with_sharding_constraint( + jax_result, NamedSharding(mesh, PartitionSpec(None)) + ) + + # Compare results + assert_allclose(gathered_te, gathered_jax, dtype=dtype) + + def _te_sum_dense(self, x, weight, bias, contracting_dims): + """TE GEMM function for gradient testing""" + return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims)) + + def _jax_sum_dense(self, x, weight, bias, contracting_dims): + """JAX dot function for gradient testing""" + result = ( + jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias + ) + return jnp.sum(result) + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_te_distributed_dense_grad( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM gradients against JAX dot gradients""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = 
Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) + + with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + # Test gradients w.r.t. all inputs + te_grad_func = jax.jit( + jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + jax_grad_func = jax.jit( + jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + + te_val, te_grads = te_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + jax_val, jax_grads = jax_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + + # Compare forward pass + assert_allclose(te_val, jax_val, dtype=dtype) + + # Compare gradients + for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)): + te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None) + jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None) + assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]" + gathered_te_grad = jax.lax.with_sharding_constraint( + te_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax_grad = jax.lax.with_sharding_constraint( + jax_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + assert_allclose( + gathered_te_grad, + gathered_jax_grad, + dtype=dtype, + err_msg=f"Gradient mismatch for argument {i}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index f3296277c8..5fa08fa089 100644 
--- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -66,18 +66,19 @@ def generate_collectives_count_ref( self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) + # TODO(Phuong) is_dp_enabled = dp mesh axis size > 1 is_dp_enabled = mesh_resource.dp_resource is not None + is_tpsp_enabled = mesh_resource.tpsp_resource is not None assert ln_type in ["layernorm", "rmsnorm"] - all_reduce_loss_bytes = 4 # 1 * FP32 - # for loss, dgamma and dbeta - # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp - weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1 - allreduce_total_bytes = ( - all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize - ) - other_bytes = 0 + # loss, 1 FP32 + allreduce_total_bytes = 4 if is_dp_enabled else 0 + # dgamma and dbeta + weight_count = 2 if ln_type == "layernorm" else 1 + allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize return generate_collectives_count( - allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes + allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled), + allgather=0, + other=0, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index a44921c641..d38f43d002 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -48,7 +48,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, 
HIDDEN_AXES) @@ -59,19 +59,47 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -INTERMEDIATE = 64 +INTERMEDIATE = 256 # Only test with FSDP and TPSP as DP is not used def generate_fsdp_and_tpsp_configs(): configs = [] + if is_devices_enough(4): + configs.append( + pytest.param( + [ + 4, + (2, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp2", + ) + ) + if is_devices_enough(2): configs.append( - [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (1, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp1_tpsp2", + ) ) - if is_devices_enough(4): configs.append( - [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (2, 1), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp1", + ), ) return configs @@ -229,10 +257,7 @@ def _test_layernorm_mlp_grad( fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 - if fwd_test_type == jnp.float16 and use_bias: - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) - else: - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: @@ -381,6 +406,7 @@ def _test_layernorm_mlp( assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) + # TODO(Phuong): check if these tols updates are still needed atol = None rtol = None l40_tolerance_update = ( @@ -404,9 +430,10 @@ def _test_layernorm_mlp( # within tolerance to the float32 ground truth. 
jax_triton_gemm_precision_tolerance_update = ( with_jax_gemm - and isinstance(fp8_recipe, recipe.Float8CurrentScaling) - and dtype == jnp.bfloat16 - and activation_type == ("gelu", "linear") + and fp8_recipe is not None + and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) + and dtype in (jnp.bfloat16, jnp.float16) + and activation_type == ("gelu", "linear"), ) if jax_triton_gemm_precision_tolerance_update: atol = 0.08 diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e5fcdac3c8..865efe89da 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -451,23 +451,19 @@ def _dims_are_consecutive(dims): output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) # Validate bias - bias_shape = (0,) - bias_dtype = out_dtype if fuse_bias: - expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) - if not grad: - assert bias.size == expected_bias_size, ( - "cuBLAS GEMM bias tensor has incorrect shape, " - f"expected ({expected_bias_size}, ) but found {bias.shape}." - ) - assert bias.dtype == out_dtype, ( - "cuBLAS GEMM bias tensor has incorrect data type, " - f"expected {bias_dtype} but found {bias.dtype}." - ) - bias_shape = bias.shape - else: - bias_shape = rhs_non_contracting_shape - bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + assert bias.shape == tuple(rhs_non_contracting_shape), ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {out_dtype} but found {bias.dtype}." 
+ ) + # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we + # change the fuse_bias value in the sharded_impl + dbias_shape = bias.shape if grad else (0,) + bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype) # Validate pre-GeLU pre_gelu_shape = (0,) @@ -548,7 +544,7 @@ def lowering( } operand_output_aliases = {} - if fuse_bias and not grad: + if grad: operand_output_aliases.update({4: 1}) # bias <-> bias_grad if fuse_gelu and grad: operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out @@ -927,7 +923,6 @@ def infer_sharding_from_operands( del ( out_dtype, scaling_mode, - grad, use_split_accumulator, result_infos, is_outer, @@ -941,8 +936,8 @@ def infer_sharding_from_operands( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard dbias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) @@ -1008,8 +1003,8 @@ def partition( # Assemble output shardings out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard bias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) @@ -1019,6 +1014,8 @@ def partition( out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + # We should not fuse bias in the output reduction case + sharded_fuse_bias = fuse_bias and reduce_spec is None outputs = GemmPrimitive.impl( lhs, lhs_scale_inv, @@ -1029,7 +1026,7 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): out_dtype=out_dtype, contracting_dims=contracting_dims, 
scaling_mode=scaling_mode, - fuse_bias=fuse_bias, + fuse_bias=sharded_fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, @@ -1039,13 +1036,17 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): collective_op=collective_op, ) - if reduce_spec is not None and not collective_op.is_reduce_scatter: - if is_all_reduce_in_float32(): # For unittest only - outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( - out_dtype - ) - else: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if reduce_spec is not None: + if not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # For unittest only + outputs[0] = jax.lax.psum( + outputs[0].astype(jnp.float32), reduce_spec + ).astype(out_dtype) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + + if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias + outputs[0] += bias return outputs @@ -1068,7 +1069,7 @@ def shardy_sharding_rule( operand_types, result_types, ): - del out_dtype, grad, use_split_accumulator + del out_dtype, use_split_accumulator del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer if not collective_op.is_none: @@ -1079,12 +1080,6 @@ def shardy_sharding_rule( prefix = "Gemm_" - warnings.warn( - "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," - " please turn off Shardy by exporting the environment variable" - " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." 
- ) - def _generate_operand_rules(name, ndim, cdims): specs = [] ldims = tuple(i for i in range(ndim) if i not in cdims) @@ -1118,7 +1113,8 @@ def _generate_operand_rules(name, ndim, cdims): rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) bias_spec = rhs_non_cspec if fuse_bias else ("…4",) - gelu_spec = out_spec if fuse_gelu else ("…5",) + dbias_spec = bias_spec if grad else ("…5") + gelu_spec = out_spec if fuse_gelu else ("…6",) return SdyShardingRule( operand_mappings=( @@ -1131,7 +1127,7 @@ def _generate_operand_rules(name, ndim, cdims): ), result_mappings=( out_spec, - bias_spec, + dbias_spec, gelu_spec, ), ) @@ -1161,6 +1157,13 @@ def _te_gemm( collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: + if grad or fuse_gelu: + warnings.warn( + "GEMM + fused grad or fused gelu is not well tested and will be deprecated in the" + " future", + DeprecationWarning, + ) + # Prepare non-quantized GEMM operands lhs_data = lhs rhs_data = rhs @@ -1228,7 +1231,7 @@ def _te_gemm( grad=grad, use_split_accumulator=use_split_accumulator, transpose_batch_sequence=transpose_batch_sequence, - sequence_dim=-1, + sequence_dim=-1, # Dummy value and will be set in the primitive is_outer=True, collective_op=collective_op, ) @@ -1618,6 +1621,7 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + # TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None fuse_bias = kwargs.get("fuse_bias", False) fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3348c725be..ef63736880 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -28,7 +28,7 @@ 
get_cudnn_version, ) from .quantization import _quantize_dbias_impl, AmaxScope -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, @@ -801,9 +801,9 @@ def sharded_impl(dz, x, mu, rsigma, gamma): norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, ) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) + global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh) if norm_type == NVTE_Norm_Type.LayerNorm: - global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) + global_dbeta = all_reduce_sum_along_dp_fsdp_tpsp(local_dbeta, mesh) else: global_dbeta = local_dbeta return local_dx, global_dgamma, global_dbeta diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 1467fa8873..f2007efcf6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -158,18 +158,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; - std::vector bias_shape = {0}; + size_t bias_size = 0; DType bias_dtype = out_dtype; if (fuse_bias) { - if (!grad) { + if (grad) { NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); } - bias_ptr = bias_grad->untyped_data(); - bias_shape.at(0) = bias_grad->dimensions().front(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + bias_ptr = bias.untyped_data(); + bias_size = product(bias.dimensions()); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); } - auto bias_ = TensorWrapper(bias_ptr, bias_shape, 
bias_dtype); + auto bias_ = TensorWrapper(bias_ptr, std::vector{bias_size}, bias_dtype); // Pre-GeLU output from forward pass or input to backward pass void *pre_gelu_ptr = nullptr; @@ -202,6 +202,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), rhs_transposed, lhs_transposed, grad, workspace_.data(), false, @@ -220,6 +222,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i buffer_shape[1] = out_shape[1]; out_shape[0] = out_shape[0] / comm_handler.tp_size; } + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( buffer_shape, buffer_dtype, collective_op); if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index d3a7952d39..8eeaca4cc8 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -365,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) +def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh): + """Perform all-reduce sum operation along data parallelism and sequence parallelism axes. 
+ + Args: + x: Input tensor to reduce + mesh: JAX mesh for distributed computation + + Returns: + Reduced tensor + """ + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh) + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) + + def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): """Perform all-reduce max operation along all axes except pipeline parallelism. From 56e2fede5ace495c0aa817e802ea3504c5974b11 Mon Sep 17 00:00:00 2001 From: Kiv Chen <34561254+KivenChen@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:38:57 -0700 Subject: [PATCH 64/78] [Build] fix: TE installation failed to find uv-installed cuDNN libraries (#2207) [Build] fix: python platlib path Signed-off-by: Kiv Chen Co-authored-by: Kirthi Shankar Sivamani --- build_tools/build_ext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 3aa45f0241..349858ac49 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -57,6 +57,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: build_dir, f"-DPython_EXECUTABLE={sys.executable}", f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}", + f"-DPython_SITEARCH={sysconfig.get_path('platlib')}", f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", ] From 9f3e79bff824d3a9f10267dc414308011c87b093 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 6 Oct 2025 15:01:04 -0700 Subject: [PATCH 65/78] =?UTF-8?q?[PyTorch]=20Fix=20tests=20for=20?= =?UTF-8?q?=F0=9F=A4=97=20integration=20(#2239)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update test requirements for HF Signed-off-by: Kirthi Shankar Sivamani * Update build_tools/pytorch.py Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- 
build_tools/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index a974e370d7..3d44d8740c 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -19,7 +19,7 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions.""" - return ["numpy", "torchvision", "transformers"] + return ["numpy", "torchvision", "transformers", "torchao==0.13"] def setup_pytorch_extension( From 127b6d3ab3088c403f3b38b8405b70fc33ee3f34 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 7 Oct 2025 10:59:24 -0400 Subject: [PATCH 66/78] [JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238) * reuse amax for current scaling Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 371 ++++++++++++++---- .../jax/cpp_extensions/normalization.py | 198 ++++++++-- .../jax/cpp_extensions/quantization.py | 61 +-- .../jax/csrc/extensions/activation.cpp | 99 +++-- .../jax/csrc/extensions/normalization.cpp | 60 +-- .../jax/csrc/extensions/quantization.cpp | 6 +- transformer_engine/jax/dense.py | 30 +- transformer_engine/jax/flax/module.py | 16 + transformer_engine/jax/flax/transformer.py | 5 + transformer_engine/jax/layernorm_dense.py | 19 + transformer_engine/jax/layernorm_mlp.py | 41 +- 11 files changed, 677 insertions(+), 229 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 925c1d01ae..be1f9f9564 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -148,7 +148,6 @@ class ActLuPrimitive(BasePrimitive): name = "te_act_lu_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, @@ -156,7 +155,11 @@ class ActLuPrimitive(BasePrimitive): 7, 8, 9, - ) # out_dtype, act_enum, act_len, scaling_mode, 
is_2x, scale_dtype, is_outer, act_params + 10, + 11, + 12, + 13, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer inner_primitive = None outer_primitive = None @@ -164,6 +167,7 @@ class ActLuPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, act_enum, @@ -171,16 +175,23 @@ def abstract( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_act_lu_p abstract """ - del act_enum, act_params + del act_enum, act_params, amax_scope, transpose_batch_sequence + assert ( + not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value + ), f"scaling_mode = {scaling_mode}" dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert x_aval.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x_aval.shape} and act_len {act_len}" @@ -215,6 +226,7 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, act_enum, @@ -222,24 +234,34 @@ def lowering( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_gated_act_lu_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - x_aval, scale_aval = ctx.avals_in + del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( + assert amax_aval.dtype == jnp.float32 + + out = ffi.ffi_lowering( + 
ActLuPrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( ctx, x, scale, + amax, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x, act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return out @@ -247,14 +269,18 @@ def lowering( def impl( x, scale, + amax, out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe implementation @@ -266,14 +292,18 @@ def impl( ActLuPrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, - is_outer=False, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=False, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -301,17 +331,19 @@ def batcher( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe batch rules for vmap """ - del act_len, is_outer check_valid_batch_dims(batch_dims) assert ActLuPrimitive.outer_primitive is not None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims + x, scale, amax = batched_args + x_bdim, scale_bdim, _ = batch_dims amax_bdim = scale_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim @@ -319,12 +351,18 @@ def batcher( ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, + act_len=act_len, scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, 
) @@ -337,8 +375,11 @@ def infer_sharding_from_operands( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, @@ -349,8 +390,11 @@ def infer_sharding_from_operands( act_enum, scale_dtype, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -402,13 +446,16 @@ def partition( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, ): - del result_infos, is_outer # Unused. + del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -452,26 +499,40 @@ def partition( amax_sharding, ) - def sharded_impl(x, scale): - local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = ( - ActLuPrimitive.impl( - x, - scale, - out_dtype=out_dtype, - act_enum=act_enum, - act_len=act_len, - scaling_mode=scaling_mode, - is_2x=is_2x, - scale_dtype=scale_dtype, - is_outer=True, - act_params=act_params, - ) + def sharded_impl(x, scale, amax): + ( + local_x, + local_colwise_x, + local_scale_inv, + local_colwise_scale_inv, + local_updated_amax, + ) = ActLuPrimitive.impl( + x, + scale, + amax, + out_dtype=out_dtype, + act_enum=act_enum, + act_len=act_len, + scaling_mode=scaling_mode, + is_2x=is_2x, + scale_dtype=scale_dtype, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + 
local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, out_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -491,13 +552,28 @@ def shardy_sharding_rule( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params + del ( + out_dtype, + act_enum, + act_len, + scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, + mesh, + result_types, + ) prefix = "ActLu_" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] @@ -526,6 +602,7 @@ def shardy_sharding_rule( ( x_axes, ("…1",), + amax, ), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), **scale_rules.factor_sizes, @@ -543,8 +620,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -553,6 +630,7 @@ def abstract( dz_aval, x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, @@ -561,13 +639,16 @@ def abstract( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ 
te_dact_dbias_quantize_p abstract """ - del act_enum, act_params + del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -576,6 +657,7 @@ def abstract( f" {x_aval.shape} and act_len {act_len}" ) assert scale_aval.dtype == jnp.float32 + assert amax_aval.dtype == jnp.float32 assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( "Current tensor scaling is not supported for fused dact and quantization. Please do" @@ -655,6 +737,7 @@ def lowering( dz, x, scale, + amax, *, out_dtype, scaling_mode, @@ -663,27 +746,42 @@ def lowering( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_dact_dbias_quantize_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - dz_aval, x_aval, scale_aval = ctx.avals_in + del ( + out_dtype, + scale_dtype, + act_len, + is_outer, + amax_scope, + transpose_batch_sequence, + ) + dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_aval.dtype - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDActLuDBiasQuantizePrimitive.name, + operand_output_aliases={3: 4}, # donate amax buffer to updated_amax + )( ctx, dz, x, scale, + amax, scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod @@ -691,6 +789,7 @@ def impl( dz, x, scale, + amax, out_dtype, scaling_mode, is_2x, @@ -698,8 +797,11 @@ def impl( is_dbias, act_enum, act_len, - is_outer, act_params, + 
amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_dact_dbias_quantize_p impl @@ -711,6 +813,7 @@ def impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -718,8 +821,11 @@ def impl( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, - is_outer=False, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=False, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -747,17 +853,19 @@ def batcher( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None - dz, x, scale = batched_args - _, x_bdim, scale_bdim = batch_dims + dz, x, scale, amax = batched_args + _, x_bdim, scale_bdim, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -772,6 +880,7 @@ def batcher( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -780,6 +889,10 @@ def batcher( act_enum=act_enum, act_len=act_len, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -793,14 +906,18 @@ def infer_sharding_from_operands( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum, act_params - del scale_dtype, act_len, is_outer + del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling + del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_spec = get_padded_spec(arg_infos[1]) 
scale_spec = get_padded_spec(arg_infos[2]) @@ -869,8 +986,11 @@ def partition( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, @@ -937,12 +1057,13 @@ def partition( dbias_sharding, ) - def sharded_impl(dz, x, scale): - (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( + def sharded_impl(dz, x, scale, amax): + (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = ( BaseDActLuDBiasQuantizePrimitive.impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -950,8 +1071,11 @@ def sharded_impl(dz, x, scale): is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, - is_outer=True, act_params=act_params, + output_amax_when_no_scaling=output_amax_when_no_scaling, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + is_outer=True, ) ) if is_dbias: @@ -960,9 +1084,15 @@ def sharded_impl(dz, x, scale): global_dbias = local_dbias if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias @@ -977,14 +1107,30 @@ def shardy_sharding_rule( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, 
result_types, act_params + del ( + out_dtype, + scale_dtype, + act_enum, + act_len, + act_params, + is_outer, + output_amax_when_no_scaling, + mesh, + result_types, + amax_scope, + transpose_batch_sequence, + ) + prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 @@ -1006,7 +1152,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (dz_axes, x_axes, ("…2",)), + (dz_axes, x_axes, ("…2",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), **scale_rules.factor_sizes, ) @@ -1092,6 +1238,8 @@ def act_lu( quantizer: Optional[Quantizer] = None, act_params: Optional[ActivationParams] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1108,6 +1256,8 @@ def act_lu( If quantizer is provided: A ScaledTensor containing the quantized activated input. 
""" + # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl() + # Do the same with dact_dbias_quantize() and layernorm_fwd() act_type_id = ActivationEnum[activation_type].value act_len = len(activation_type) assert x.shape[-2] == act_len, ( @@ -1123,30 +1273,44 @@ def act_lu( return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params + f=act_lu, + x=x, + activation_type=activation_type, + quantizer=quantizer, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) + if quantizer is None: - out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( + out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) out = out.reshape(output_shape) + # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor out = NoScaleTensor( data=out, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return out @@ -1157,6 +1321,9 @@ def act_lu( activation_type=activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( 
out, @@ -1164,6 +1331,7 @@ def act_lu( quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out if isinstance(quantizer, DelayedScaleQuantizer): @@ -1178,14 +1346,18 @@ def act_lu( ) = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=quantizer.q_dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) quantizer.update(updated_amax) @@ -1209,6 +1381,9 @@ def quantize_dact_dbias( is_dbias: bool = True, quantizer: Optional[Quantizer] = None, act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. 
@@ -1232,7 +1407,8 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) - scale = jnp.empty((), jnp.float32) + scale = jnp.empty((1,), jnp.float32) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive if not PrimitiveClass.enabled() or ( @@ -1240,10 +1416,11 @@ def quantize_dact_dbias( ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind( dz, x, scale, + amax, # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset @@ -1253,8 +1430,11 @@ def quantize_dact_dbias( is_dbias=False, act_enum=act_type_id, act_len=act_len, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) output = output.astype(x.dtype) dbias = None @@ -1263,7 +1443,7 @@ def quantize_dact_dbias( output = NoScaleTensor( data=output, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return output, dbias @@ -1275,9 +1455,18 @@ def quantize_dact_dbias( activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return _quantize_dbias_impl( - out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) is_gated = act_len == 2 @@ -1292,6 +1481,9 @@ def 
quantize_dact_dbias( quantizer=quantizer, flatten_axis=-2, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output @@ -1304,9 +1496,18 @@ def quantize_dact_dbias( activation_type=activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, dbias = _quantize_dbias_impl( - out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out, + is_dbias=is_dbias, + quantizer=quantizer, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1320,9 +1521,17 @@ def quantize_dact_dbias( x.astype(jnp.float32), activation_type=activation_type, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + dgated, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1337,6 +1546,7 @@ def quantize_dact_dbias( dz, x, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), @@ -1344,8 +1554,11 @@ def quantize_dact_dbias( is_dbias=is_dbias, act_enum=act_type_id, act_len=act_len, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1375,6 +1588,9 @@ def dact_lu( activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, act_params: 
Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1396,5 +1612,8 @@ def dact_lu( is_dbias=False, quantizer=quantizer, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return output diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index ef63736880..3ce8a19a76 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -92,7 +92,7 @@ class NormFwdPrimitive(BasePrimitive): name = "te_norm_forward_ffi" multiple_results = True - impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11) + impl_static_args = (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -100,6 +100,7 @@ class NormFwdPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, gamma_aval, beta_aval, *, @@ -110,15 +111,27 @@ def abstract( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd inner primitive abstract """ + del amax_scope, transpose_batch_sequence + assert not output_amax_when_no_scaling or ( + scaling_mode == ScalingMode.NO_SCALING.value + and not is_norm_fwd_cudnn_enabled(scaling_mode) + ), ( + f"scaling_mode = {scaling_mode}," + f" use_cudnn_norm_fwd={is_norm_fwd_cudnn_enabled(scaling_mode)}" + ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -220,6 
+233,7 @@ def lowering( ctx, x, scale, + amax, gamma, beta, *, @@ -230,16 +244,20 @@ def lowering( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd lowering rules """ - del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in + del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape @@ -251,10 +269,14 @@ def lowering( assert g_shape == b_shape sm_margin = get_forward_sm_margin() - return ffi.ffi_lowering(NormFwdPrimitive.name)( + return ffi.ffi_lowering( + NormFwdPrimitive.name, + operand_output_aliases={2: 4}, # amax <-> updated_amax + )( ctx, x, scale, + amax, gamma, beta, norm_type=norm_type.value, @@ -263,12 +285,14 @@ def lowering( sm_margin=sm_margin, scaling_mode=scaling_mode.value, is_2x=is_2x, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod def impl( x, scale, + amax, gamma, beta, norm_type, @@ -278,6 +302,9 @@ def impl( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ @@ -297,6 +324,7 @@ def impl( ) = NormFwdPrimitive.inner_primitive.bind( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ -306,6 +334,9 @@ def impl( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=False, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -341,16 +372,18 @@ def batcher( scaling_mode, is_2x, scale_dtype, + amax_scope, + 
transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert NormFwdPrimitive.outer_primitive is not None - x, scale, gamma, beta = batched_args - x_bdim, scale_bdim, _, _ = batch_dims + x, scale, amax, gamma, beta = batched_args + x_bdim, scale_bdim, _, _, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -363,8 +396,9 @@ def batcher( ) return ( NormFwdPrimitive.outer_primitive.bind( - scale, x, + scale, + amax, gamma, beta, norm_type=norm_type, @@ -374,6 +408,10 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -387,15 +425,19 @@ def infer_sharding_from_operands( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, result_infos, ): del zero_centered_gamma, epsilon, out_dtype, result_infos - del scale_dtype, is_outer + del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( @@ -415,9 +457,9 @@ def infer_sharding_from_operands( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec + scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -445,6 +487,9 @@ def partition( scaling_mode, is_2x, scale_dtype, + 
amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, @@ -453,8 +498,9 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) - g_spec = get_padded_spec(arg_infos[2]) - b_spec = get_padded_spec(arg_infos[3]) + amax_spec = get_padded_spec(arg_infos[2]) + g_spec = get_padded_spec(arg_infos[3]) + b_spec = get_padded_spec(arg_infos[4]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: @@ -485,9 +531,9 @@ def partition( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec + scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -499,10 +545,10 @@ def partition( arg_shardings = list(arg_i.sharding for arg_i in arg_infos) # Enforce no sharding of hidden dim for x, gamma and beta arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x") - arg_shardings[2] = NamedSharding( + arg_shardings[3] = NamedSharding( mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma" ) - arg_shardings[3] = NamedSharding( + arg_shardings[4] = NamedSharding( mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta" ) arg_shardings = tuple(arg_shardings) @@ -516,19 +562,20 @@ def partition( rsigma_sharding, ) - def sharded_impl(x, scale, gamma, beta): + def sharded_impl(x, scale, amax, gamma, beta): # expect tp and dp giving same shape, or tp being same shape as global ( local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, - local_amax, + local_updated_amax, local_mu, local_rsigma, ) = NormFwdPrimitive.impl( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ 
-538,12 +585,21 @@ def sharded_impl(x, scale, gamma, beta): scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -566,6 +622,9 @@ def shardy_sharding_rule( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, value_types, @@ -576,6 +635,9 @@ def shardy_sharding_rule( epsilon, out_dtype, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, result_types, @@ -594,7 +656,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (x_axes, ("…1",), ("…2",), ("…3",)), + (x_axes, ("…1",), amax, ("…2",), ("…3",)), ( out, colwise_out, @@ -882,6 +944,8 @@ def layernorm_fwd( epsilon: float, quantizer: Optional[Quantizer], amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -896,6 +960,7 @@ def layernorm_fwd( epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. 
This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -918,10 +983,12 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -931,18 +998,37 @@ def layernorm_fwd( scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=False, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), mu, rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, mu, rsigma = layernorm_fwd( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out, _ = _quantize_dbias_impl( + out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence ) - out, _ = _quantize_dbias_impl(out, quantizer) return out, mu, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -954,6 +1040,9 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, 
epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( out, @@ -961,6 +1050,7 @@ def layernorm_fwd( quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, mu, rsigma @@ -979,6 +1069,7 @@ def layernorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -988,6 +1079,9 @@ def layernorm_fwd( scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) @@ -1091,7 +1185,9 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], - amax_scope: AmaxScope = AmaxScope.LOCAL, + amax_scope: AmaxScope = AmaxScope.TPSP, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1104,6 +1200,7 @@ def rmsnorm_fwd( epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. 
Returns: A tuple containing: @@ -1127,12 +1224,14 @@ def rmsnorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) beta = jnp.ones((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, _, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1142,16 +1241,39 @@ def rmsnorm_fwd( scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): - out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out.data, quantizer) + out, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out, _ = _quantize_dbias_impl( + out.data, + quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1162,13 +1284,17 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + 
output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( - out.data, + out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, rsigma @@ -1187,6 +1313,7 @@ def rmsnorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1196,6 +1323,9 @@ def rmsnorm_fwd( scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) @@ -1294,6 +1424,7 @@ def normalization_fwd( norm_type: str, quantizer: Optional[Quantizer], amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, ): """Common wrapper for normalization forward pass. @@ -1311,6 +1442,7 @@ def normalization_fwd( - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. 
Returns: A tuple containing: @@ -1336,6 +1468,7 @@ def normalization_fwd( epsilon, quantizer, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) elif norm_type == "rmsnorm": assert ( @@ -1348,6 +1481,7 @@ def normalization_fwd( epsilon, quantizer, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) mu = None else: diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 9f9e8fec06..38fd50a00f 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -543,6 +543,18 @@ class AmaxScope(Enum): TPSP = 2 FSDP = 3 + def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): + """Reduce the amax based on its scope""" + gmesh = global_mesh_resource() + sequence_dim = 0 if transpose_batch_sequence else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: + return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + # Run AR across FSDP + if self is AmaxScope.FSDP: + return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + return amax + class AmaxCalculationPrimitive(BasePrimitive): """ @@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive): impl_static_args = ( 1, 2, - ) # amax_scope, batch_sequence_transpose + ) # amax_scope, transpose_batch_sequence inner_primitive = None outer_primitive = None @@ -563,12 +575,12 @@ def abstract( x_aval, *, amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, ): """ amax calcuation abstract """ - del amax_scope, batch_sequence_transpose + del amax_scope, transpose_batch_sequence dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -580,19 +592,19 @@ def abstract( def impl( x, amax_scope, - 
batch_sequence_transpose, + transpose_batch_sequence, ): """ amax calcuation implementation """ - del amax_scope, batch_sequence_transpose + del amax_scope, transpose_batch_sequence amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) return amax @staticmethod def infer_sharding_from_operands( amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, mesh, arg_infos, result_infos, @@ -600,7 +612,7 @@ def infer_sharding_from_operands( """ amax calcuation infer_sharding_from_operands """ - del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. + del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. amax_sharding = NamedSharding( mesh, PartitionSpec(None), @@ -611,7 +623,7 @@ def infer_sharding_from_operands( @staticmethod def partition( amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, mesh, arg_infos, result_infos, @@ -631,16 +643,11 @@ def sharded_impl(x): amax = AmaxCalculationPrimitive.impl( x, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh ) - gmesh = global_mesh_resource() - sequence_dim = 0 if batch_sequence_transpose else 1 - # Run AR across TPSP only when tensor-sequence is detected in the input spec - if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource: - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - # Run AR across FSDP - if amax_scope is AmaxScope.FSDP: - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) return amax @@ -648,11 +655,11 @@ def sharded_impl(x): return mesh, sharded_impl, amax_sharding, arg_shardings @staticmethod - def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): + def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, 
value_types, result_types): """ amax calcuation shardy_sharding_rule """ - del amax_scope, batch_sequence_transpose, mesh, result_types + del amax_scope, transpose_batch_sequence, mesh, result_types prefix = "AmaxCal" input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) output_spec = (f"{prefix}_amax",) @@ -709,7 +716,7 @@ def _quantize_dbias_impl( dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -755,12 +762,12 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - scale = jnp.empty((), jnp.float32) + scale = jnp.empty((1,), jnp.float32) amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. @@ -771,7 +778,7 @@ def _quantize_dbias_impl( amax = AmaxCalculationPrimitive.outer_primitive.bind( x.data, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: @@ -845,7 +852,7 @@ def quantize( quantizer: Quantizer, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. 
@@ -866,7 +873,7 @@ def quantize( quantizer=quantizer, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) return out @@ -877,7 +884,7 @@ def quantize_dbias( is_dbias: bool = True, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -904,7 +911,7 @@ def quantize_dbias( is_dbias=is_dbias, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 0ecf791505..f512321c38 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -15,13 +15,14 @@ namespace transformer_engine { namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, + Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int, ActivationConfig act_params) { + Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -30,7 +31,9 @@ 
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto *output = output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data(); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto input_dims = input_buf.dimensions(); auto m = product(input_dims, 0, input_dims.size() - 2); @@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_trans_shape = std::vector{static_cast(n), m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -145,26 +150,29 @@ 
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") .Attr("is_2x") - .Attr("act_params"), + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int, - ActivationConfig act_params) { - return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, - colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int, act_params); + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params, bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, + output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, + updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, + output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") .Attr("is_2x") - .Attr("act_params")); 
+ .Attr("act_params") + .Attr("output_amax_when_no_scaling")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, - ActivationConfig act_params) { + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params, + bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto *input = input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data(); float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, 
"amax and updated_amax should be aliased"); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Arg() // input .Arg() // act input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("act_enum") .Attr("is_2x") .Attr("is_dbias") - .Attr("act_params"), + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type DActLuDBiasQuantizeInitializeFFI( cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type 
dbias_buf, - Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, - bool is_dbias, ActivationConfig act_params) { + Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, - act_input_buf, scale_buf, output_buf, colwise_output_buf, - scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); + act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, + output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Arg() // input .Arg() // act input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") .Attr("is_dbias") - .Attr("act_params")); + .Attr("act_params") + .Attr("output_amax_when_no_scaling")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 5238193922..378e009c83 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -29,6 +29,7 @@ pybind11::tuple 
GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); + output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { @@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si } Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, - Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { + Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, + double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, + bool is_2x, bool output_amax_when_no_scaling) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *output = output_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); auto *mu = mu_buf->untyped_data(); - auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = 
wkspace_buf->untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); + auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); } if (_is_2x) { @@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Ctx() // stream .Arg() // x .Arg() // scale + .Arg() // amax .Arg() // gamma .Arg() // beta .Ret() // output .Ret() // colwise_output .Ret() // scale_inv .Ret() // colwise_scale_inv - .Ret() // amax + .Ret() // updated_amax .Ret() // mu .Ret() // rsigma .Ret() // wkspace @@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type gamma_buf, 
Buffer_Type beta_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type mu_buf, - Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, + Buffer_Type amax_buf, Buffer_Type gamma_buf, + Buffer_Type beta_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, int64_t sm_margin, - JAXX_Scaling_Mode scaling_mode, bool is_2x) { - return wrapInStreamCapture( - std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, - colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, - wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); + JAXX_Scaling_Mode scaling_mode, bool is_2x, + bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, + gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, + colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, + scaling_mode, is_2x, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, @@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Ctx() // stream .Arg() // x .Arg() // scale + .Arg() // amax .Arg() // gamma .Arg() // beta .Ret() // output .Ret() // colwise_output .Ret() // scale_inv .Ret() // colwise_scale_inv - .Ret() // amax + .Ret() // updated_amax .Ret() // mu .Ret() // rsigma .Ret() // wkspace @@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Attr("epsilon") 
.Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("output_amax_when_no_scaling")); pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index d17d83ec1e..05260741b6 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (is_fp8_dtype(out_dtype)) { if (is_tensor_scaling) { float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + NVTE_CHECK(amax == updated_amax && amax != nullptr, + "amax must be provided for delayed tensor scaling"); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 3cdf6ba7a1..28525a22a9 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -63,7 +63,7 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] 
= None, @@ -81,7 +81,7 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix output_axes: Logical axes for sharding the output @@ -91,8 +91,8 @@ def dense( Returns: Transformed output tensor """ - if batch_sequence_transpose: - warnings.warn("batch_sequence_transpose is not well tested, use with caution!") + if transpose_batch_sequence: + warnings.warn("transpose_batch_sequence is not well tested, use with caution!") if not get_quantize_config().is_fp8_enabled(): input_dtype = x.dtype @@ -103,7 +103,7 @@ def dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -119,7 +119,7 @@ def _dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -136,7 +136,7 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification - batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. 
input_axes: Logical axes for sharding the activation input output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix @@ -151,7 +151,7 @@ def _dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -166,7 +166,7 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -197,7 +197,7 @@ def _dense_fwd_rule( flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, amax_scope=AmaxScope.TPSP, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -215,7 +215,7 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set.forward, @@ -240,7 +240,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -274,7 +274,7 @@ def _dense_bwd_rule( flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) # GEMM NT @@ -291,7 +291,7 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set.backward, ) @@ -305,7 
+305,7 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index f02876d8f4..76865f7c12 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ features: Union[Iterable[int], int] @@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 input_axes: Tuple[str, ...] = () + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -512,6 +515,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ features: Union[Iterable[int], int] @@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] 
= None depth_scaling: float = None + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -768,6 +775,7 @@ def __call__(self, inputs: Array) -> Array: dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -775,6 +783,7 @@ def __call__(self, inputs: Array) -> Array: y, kernel, contracting_dims=(axis, contract_ind), + transpose_batch_sequence=self.transpose_batch_sequence, input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, @@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ intermediate_dim: int = 2048 @@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] 
= None ffn1_ckpt_name: str = "ffn1" ffn2_ckpt_name: str = "ffn2" + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -1160,6 +1172,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): activation_type=normalized_acts, activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), + transpose_batch_sequence=self.transpose_batch_sequence, ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1178,6 +1191,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) @@ -1188,6 +1202,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -1260,6 +1275,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_2_input_axes, kernel_axes=self.kernel_axes_2, quantizer_set=ffn2_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 868bcfa057..c95765bf3a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1207,6 +1207,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="qkv", dtype=self.dtype, )(inputs_q) @@ -1234,6 +1235,7 @@ def 
generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -1252,6 +1254,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): enable_low_rank_adaptation=lora_scope.qkv_proj, low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, + transpose_batch_sequence=self.transpose_batch_sequence, name="kv", dtype=self.dtype, )(inputs_kv) @@ -1292,6 +1295,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -2070,6 +2074,7 @@ def hidden_dropout(x, deterministic): layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), + transpose_batch_sequence=self.transpose_batch_sequence, name="mlp", )(mlp_input, deterministic=deterministic) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index fb97830759..136f43df41 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -16,6 +16,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .quantize import ( QuantizerSet, @@ -35,6 +36,7 @@ def layernorm_dense( norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, + transpose_batch_sequence: bool = False, layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] 
= None, @@ -55,6 +57,7 @@ def layernorm_dense( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix @@ -83,6 +86,7 @@ def layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -100,6 +104,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -111,6 +116,7 @@ def _layernorm_dense( norm_type: str, zero_centered_gamma: bool, epsilon: float, + transpose_batch_sequence: bool, layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], @@ -131,6 +137,7 @@ def _layernorm_dense( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding quantizer_set: Set of quantizers @@ -147,6 +154,7 @@ def _layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = 
with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) @@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule( kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_constracting_dim, g_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 77daa4672c..c43430cf36 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ 
b/transformer_engine/jax/layernorm_mlp.py @@ -41,7 +41,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, @@ -78,7 +78,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization - batch_sequence_transpose: Whether to transpose the batch and sequence dimensions + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -130,7 +130,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -158,7 +158,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, - batch_sequence_transpose: bool, + transpose_batch_sequence: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -188,7 +188,7 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability - batch_sequence_transpose: Whether to transpose the batch and sequence dimensions + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding @@ -214,7 +214,7 
@@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule( norm_type, quantizer=ffn1_quantizer_set.x, amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set_1.forward, @@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule( if activation_params else None ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule( kernel_2, quantizer=ffn2_quantizer_set.kernel, amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), 
contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set_2.forward, @@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule( is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set_2.backward, ) @@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule( if activation_params else None ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), - transpose_batch_sequence=batch_sequence_transpose, + 
transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set_1.backward, ) @@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From 76bced540eb264a194b8cd28f8894f860d841e6a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 7 Oct 2025 12:17:24 -0700 Subject: [PATCH 67/78] `NVFP4BlockScaling` recipe docs (#2241) * Improve docstring for NVFP4 recipe Signed-off-by: Kirthi Shankar Sivamani * Add NVFP4BlockScaling to recipe docs Signed-off-by: Kirthi Shankar Sivamani * Grammar Signed-off-by: Kirthi Shankar Sivamani * improve wording Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Przemyslaw Tredak Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/api/common.rst | 2 ++ transformer_engine/common/recipe/__init__.py | 28 +++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/docs/api/common.rst b/docs/api/common.rst index 541118985d..3edd7cae21 100644 --- a/docs/api/common.rst 
+++ b/docs/api/common.rst @@ -12,6 +12,8 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) +.. autoapiclass:: transformer_engine.common.recipe.NVFP4BlockScaling(fp4_format=Format.E2M1) + .. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) .. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1a9b029878..1204c37c5b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -401,16 +401,32 @@ class NVFP4BlockScaling(Recipe): computed from the high precision input to avoid double quantization errors. + The default NVFP4 training recipe implements 3 techniques for quantizing + to a narrow format (4-bit): + + - For weight tensors a variant of the NVFP4 quantization is used, + where a single scaling factor is shared by a 2D block of 16x16 elements. + - When quantizing gradients, stochastic rounding is applied to avoid the bias + introduced by quantization. With this, values are rounded probabilistically + to one of their two nearest representable numbers, with probabilities + inversely proportional to their distances. + - When quantizing inputs and gradients, random Hadamard transforms are applied + (16x16 Hadamard matrix) to smooth outliers in the tensor distributions + and make them easier to represent accurately in NVFP4. + + These techniques are described more comprehensively in the NVFP4 paper titled + 'Pretraining Large Language Models with NVFP4' (https://arxiv.org/abs/2509.25149v1). + Parameters ---------- fp4_format : {Format.E2M1}, default = Format.E2M1 FP4 data type. - fp8_format : {Format.E4M3}, default = Format.E4M3 - FP8 data type. Only E4M3 is supported. - fp8_dpa: bool, default = `False` - FP8 dot product attention. Not yet supported. 
- fp8_mha: bool, default = `False` - FP8 multi-head attention. Not yet supported. + disable_rht : bool, default = `False` + If set to `True`, random Hadamard transforms are not applied to any tensor. + disable_stochastic_rounding : bool, default = `False` + If set to `True`, stochastic rounding is disabled during quantization for all tensors. + disable_2d_quantization : bool, default = `False` + If set to `True`, 1D block scaling with block size 16 is used for all tensors. """ # Configuration envvars From ac5e868f143401f04664b8cb8f39d806ac912078 Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 7 Oct 2025 12:51:54 -0700 Subject: [PATCH 68/78] Skip fp8 tests on unsupported devices (#2243) Signed-off-by: Vladimir Cherepanov --- tests/cpp_distributed/test_comm_gemm.cu | 31 +++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu index 8355d5f96f..884faa4748 100644 --- a/tests/cpp_distributed/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) { return supported; } +int GetDeviceComputeCapability(int device_id) { + int major{}; + int minor{}; + CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id)); + CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id)); + return major * 10 + minor; +} + +template +bool IsDTypeSupported(int /* device_id */) { + return true; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template +bool AllDTypesSupported(int device_id) { + return (IsDTypeSupported(device_id) && ...); +} + template std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, size_t nsize, size_t ld) { @@ 
-161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam { template void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + if (!AllDTypesSupported(rank_)) + GTEST_SKIP() << "FP8 is not supported on device " << rank_; + cudaStream_t stream{}; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); From 66f9b3cbae214d521ac18883fe9a386b8893b179 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 7 Oct 2025 21:32:47 -0700 Subject: [PATCH 69/78] [PyTorch] Unblock fused bgrad quantization path for nvfp4 (#2246) Unblock path for fusing NVFP4 quantize and bgrad Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3ae3895689..838ac5281c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -40,7 +40,6 @@ from ..constants import dist_group_type from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.storage.float8_tensor_storage import Float8TensorStorage @@ -1229,8 +1228,7 @@ def grad_output_preprocess( ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - # TODO(ksivaman): Re-add fusion once kernel is available. - if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): + if isinstance(quantizer, Float8BlockQuantizer): # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2097f01b1d..d680a9f8f6 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1037,11 +1037,8 @@ def fc2_wgrad_gemm( if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now - # TODO(ksivaman): Re-add fusion once kernel is available. if ( - isinstance( - ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) - ) + isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) From af2a0c16ec11363c0af84690cd877a59f898820e Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Wed, 8 Oct 2025 08:30:56 -0700 Subject: [PATCH 70/78] [JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2213) * Try async copy of grouped GEMM group_sizes data Signed-off-by: Hua Huang --------- Signed-off-by: Hua Huang Co-authored-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 10 ++- transformer_engine/jax/cpp_extensions/gemm.py | 87 ++++++++++++++++++- transformer_engine/jax/csrc/extensions.h | 1 + .../jax/csrc/extensions/gemm.cpp | 81 +++++++++++++++-- .../jax/csrc/extensions/pybind.cpp | 3 + 5 files changed, 172 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7a4fa268af..124e0248b6 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1366,14 +1366,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, 
static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + prim_out = jax.jit( + tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") + )( lhs, rhs, group_sizes, contracting_dims, + use_async_d2h_group_sizes=True, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 865efe89da..7fe433bcc6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -58,6 +58,7 @@ "collective_gemm_bootstrap", "noop_collective_op_set", "gemm", + "grouped_gemm_copy_group_sizes", "grouped_gemm", "gemm_uses_jax_dot", "sanitize_dims", @@ -1237,6 +1238,63 @@ def _te_gemm( ) +class GroupedGemmCopySizesPrimitive(BasePrimitive): + """ + Primitive for async copying group sizes from device to host + """ + + name = "te_grouped_gemm_d2h_group_sizes_ffi" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + group_sizes_aval, + *, + num_gemms, + ): + del num_gemms + out_aval = group_sizes_aval + return out_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs) + return out + + @staticmethod + def lowering( + ctx, + group_sizes, + num_gemms, + ): + return jax.ffi.ffi_lowering( + GroupedGemmCopySizesPrimitive.name, + operand_output_aliases={0: 0}, # Mark num_gemms as the output + )( + ctx, + group_sizes, + num_gemms=num_gemms, + ) + + @staticmethod + def impl( + group_sizes, + num_gemms, + ): + assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( + group_sizes, + 
num_gemms=num_gemms, + ) + return out + + +register_primitive(GroupedGemmCopySizesPrimitive) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -1267,6 +1325,7 @@ def abstract( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): """ Grouped GEMM operation. @@ -1294,7 +1353,7 @@ def abstract( A jnp.ndarray containing the result of the grouped GEMM operation """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval - del K, lhs_is_trans, rhs_is_trans, has_bias + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_alignment_padding = 256 @@ -1341,6 +1400,7 @@ def lowering( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( @@ -1354,6 +1414,7 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @staticmethod @@ -1374,6 +1435,7 @@ def impl( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None (out, _) = GroupedGemmPrimitive.inner_primitive.bind( @@ -1393,6 +1455,7 @@ def impl( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return (out,) @@ -1661,6 +1724,24 @@ def gemm( return clean_outputs +def grouped_gemm_copy_group_sizes( + group_sizes: jnp.ndarray, + num_gemms: int, +) -> jnp.ndarray: + """ 
+ Async copy group sizes from device to host + + Args: + group_sizes: 1D array containing the sizes of each group + num_gemms: number of grouped gemm calls to be made + """ + out = GroupedGemmCopySizesPrimitive.outer_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1671,6 +1752,7 @@ def grouped_gemm( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + use_async_d2h_group_sizes: bool = False, ) -> jnp.ndarray: """ Grouped GEMM operation. @@ -1854,5 +1936,6 @@ def grouped_gemm( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index bbfc62120a..3ce6dee731 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); // Cudnn helpers diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f2007efcf6..993ec1377d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("collective_op"), FFI_CudaGraph_Traits); +size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, + int32_t *host_group_sizes) { + static std::once_flag init_flag; + static cudaEvent_t d2h_event; + static size_t host_num_gemms; + static const size_t max_num_gemms = 
1024; + //static int32_t host_group_sizes_internal[max_num_gemms]; + static int32_t *host_group_sizes_internal = nullptr; + auto init = [&]() { + NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event)); + NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms)); + }; + std::call_once(init_flag, init); + + NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr, + "Only one of dev_group_sizes and host_group_sizes can be non-nullptr."); + + if (dev_group_sizes != nullptr) { + NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ", + "supported number ", max_num_gemms, " to be downloaded in advance."); + host_num_gemms = num_gemms; + // Wait for current compute stream to finish + cudaStream_t compute_stream_0 = nvte_get_compute_stream(0); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event)); + // Async copy group_sizes from device to host + size_t copy_bytes = sizeof(int32_t) * num_gemms; + NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes, + cudaMemcpyDeviceToHost, compute_stream_0)); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0)); + return num_gemms; + } + + if (host_group_sizes != nullptr) { + if (host_num_gemms == 0) return 0; + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the previous value ", host_num_gemms, "."); + // Wait for the async copy to finish, then copy group_sizes to user buffer + // Note: This may break cudaGraph. 
+ NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event)); + memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms); + return host_num_gemms; + } +} + +Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes, + Result_Type dummy_output, size_t num_gemms) { + int32_t *dev_group_sizes = reinterpret_cast(group_sizes.untyped_data()); + GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // group_sizes + .Ret() // dummy_output + .Attr("num_gemms")); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad) { + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. 
- cudaStreamSynchronize(stream); + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); if (!is_grouped_dense_wgrad) { NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, @@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad")); + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 23d46b3384..f6b1acd439 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -69,6 +69,9 @@ pybind11::dict Registrations() { pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); From e37e33e12768a9fa397b51cd17c6425775c543ea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 8 Oct 2025 
19:19:51 -0700 Subject: [PATCH 71/78] Disallow pure E5M2 recipe for `Float8BlockScaling` (#2251) Catch unsupported GEMM during recipe init Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1204c37c5b..f70b43a7a8 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -363,6 +363,7 @@ def __post_init__(self) -> None: assert ( not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: return ( From 9bf4175f6b100219e0e02f4ca50d9d8fa5331efe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 8 Oct 2025 19:20:33 -0700 Subject: [PATCH 72/78] [PyTorch] Deprecate old `float8_tensor.py` (#2250) Deprecate old float8_tensor.py Signed-off-by: Kirthi Shankar Sivamani --- .../attention/dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 2 +- .../dot_product_attention/dot_product_attention.py | 2 +- .../pytorch/attention/dot_product_attention/utils.py | 2 +- .../pytorch/attention/multi_head_attention.py | 2 +- transformer_engine/pytorch/float8_tensor.py | 10 ++++++++++ transformer_engine/pytorch/utils.py | 6 +++--- 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 3a13758382..0ddb261d2e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,7 +29,7 @@ prepare_for_saving, restore_from_saved, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from 
transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d0ddae25ef..d1374e949e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -20,7 +20,7 @@ FusedAttnBackend, ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 88e28e3d81..df96067d65 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -30,7 +30,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ea7b0e8763..8b26a1760d 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,8 +35,8 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b2f1ff1ac9..8f01832248 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -10,7 +10,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index a771e3bb75..eeafc23c70 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,6 +4,16 @@ """Tensor class with FP8 data""" +import warnings + from .tensor.float8_tensor import Float8Tensor +warnings.warn( + "transformer_engine.pytorch.float8_tensor is deprecated and will be removed" + " in a future release. 
Float8Tensor should be imported directly through " + "`from transformer_engine.pytorch import Float8Tensor`", + DeprecationWarning, + stacklevel=2, +) + __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 8ea3623713..b1a7e3731d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -184,7 +184,7 @@ def combine_tensors( num_tensors = len(tensors) new_shape = list(tensors[0].shape) new_shape.insert(dim, num_tensors) - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(tensors[0], Float8Tensor): new_stride = list(tensors[0]._data.stride()) @@ -224,7 +224,7 @@ def forward( # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( Float8TensorStorage, ) @@ -278,7 +278,7 @@ def backward(ctx, *grad_outputs): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) split_dim = (ctx.split_dim + dims) % dims - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(grad_outputs[0], Float8Tensor): noop_ok = True From e99be1b6af1aa138194c211ad9952858b3aaee44 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 9 Oct 2025 10:17:41 -0700 Subject: [PATCH 73/78] Update minimum python version to 3.10 and add checks in CI (#2247) * Update minimum python version to 3.10 and update CI Signed-off-by: Kirthi Shankar Sivamani * review Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi 
Shankar Sivamani --- .pre-commit-config.yaml | 6 ++++++ build_tools/utils.py | 19 +++++++++++++++++++ setup.py | 3 ++- transformer_engine/jax/setup.py | 3 ++- transformer_engine/pytorch/setup.py | 3 ++- 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbe486fac7..5043d6ea22 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,3 +38,9 @@ repos: entry: clang-format -i args: ["-style=file"] files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$ + + - repo: https://github.com/netromdk/vermin + rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 + hooks: + - id: vermin + args: ['-t=3.10', '--violations'] diff --git a/build_tools/utils.py b/build_tools/utils.py index 3d8ec462c8..296f928b71 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -12,12 +12,31 @@ import shutil import subprocess import sys +import platform from pathlib import Path from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union +# Needs to stay consistent with .pre-commit-config.yaml config. +def min_python_version() -> Tuple[int]: + """Minimum supported Python version.""" + return (3, 10, 0) + + +def min_python_version_str() -> str: + """String representing minimum supported Python version.""" + return ".".join(map(str, min_python_version())) + + +if sys.version_info < min_python_version(): + raise RuntimeError( + f"Transformer Engine requires Python {min_python_version_str()} or newer, " + f"but found Python {platform.python_version()}." 
+ ) + + @functools.lru_cache(maxsize=None) def debug_build_enabled() -> bool: """Whether to build with a debug configuration""" diff --git a/setup.py b/setup.py index ed1f5b8a9d..c932da5e02 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ cuda_version, get_frameworks, remove_dups, + min_python_version_str, ) frameworks = get_frameworks() @@ -190,7 +191,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", + python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index ca83cf465e..f83375d821 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -44,7 +44,7 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers +from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements @@ -100,6 +100,7 @@ description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, + python_requires=f">={min_python_version_str()}", install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 46543acf28..08870040f3 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -45,7 +45,7 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers +from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version 
from build_tools.pytorch import ( setup_pytorch_extension, @@ -152,6 +152,7 @@ def run(self): description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, + python_requires=f">={min_python_version_str()}", install_requires=install_requirements(), tests_require=test_requirements(), ) From 8a7ab3ddc17e275fcbcd2eee8688ada265efbcad Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 9 Oct 2025 13:49:12 -0700 Subject: [PATCH 74/78] [JAX] NVFP4 support in TE/JAX (#2254) Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- examples/jax/encoder/common.py | 13 +- .../run_test_multiprocessing_encoder.sh | 2 + .../encoder/test_model_parallel_encoder.py | 118 +++-- examples/jax/encoder/test_multigpu_encoder.py | 84 ++- .../encoder/test_multiprocessing_encoder.py | 80 ++- .../jax/encoder/test_single_gpu_encoder.py | 57 +- examples/jax/mnist/test_single_gpu_mnist.py | 10 +- tests/jax/test_custom_call_compute.py | 495 +++++++++++++++--- tests/jax/test_distributed_layernorm_mlp.py | 133 +++-- tests/jax/test_helper.py | 69 ++- tests/jax/utils.py | 6 + transformer_engine/jax/__init__.py | 3 +- .../jax/cpp_extensions/__init__.py | 1 + .../jax/cpp_extensions/activation.py | 10 +- transformer_engine/jax/cpp_extensions/amax.py | 420 +++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 134 ++++- transformer_engine/jax/cpp_extensions/misc.py | 15 +- .../jax/cpp_extensions/normalization.py | 15 +- .../jax/cpp_extensions/quantization.py | 414 ++++++++------- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/amax.cpp | 100 ++++ .../jax/csrc/extensions/ffi.cpp | 3 + transformer_engine/jax/csrc/extensions/ffi.h | 2 + .../jax/csrc/extensions/gemm.cpp | 56 +- .../jax/csrc/extensions/misc.cpp | 20 +- transformer_engine/jax/csrc/extensions/misc.h | 31 +- 
.../jax/csrc/extensions/pybind.cpp | 11 +- .../jax/csrc/extensions/quantization.cpp | 189 +++++-- transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/flax/module.py | 60 +-- transformer_engine/jax/layernorm_dense.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 2 +- transformer_engine/jax/quantize/__init__.py | 1 + .../jax/quantize/dequantizer.py | 97 +++- transformer_engine/jax/quantize/hadamard.py | 72 +++ transformer_engine/jax/quantize/helper.py | 376 ++++++++++--- transformer_engine/jax/quantize/metadata.py | 16 +- transformer_engine/jax/quantize/quantizer.py | 278 +++++++++- .../jax/quantize/scaling_modes.py | 267 +++++++++- transformer_engine/jax/quantize/tensor.py | 38 +- 40 files changed, 2987 insertions(+), 721 deletions(-) create mode 100644 transformer_engine/jax/cpp_extensions/amax.py create mode 100644 transformer_engine/jax/csrc/extensions/amax.cpp create mode 100644 transformer_engine/jax/quantize/hadamard.py diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index a8bf25113a..772d5f4c14 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -33,6 +33,13 @@ def is_mxfp8_supported(): return gpu_arch >= 100 +@lru_cache +def is_nvfp4_supported(): + """Return if NVFP4 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 100 + + def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False): """Checks whether most params are sharded across sharding axis.
@@ -98,7 +105,7 @@ def assert_leaf_sharding(path, arr): ) -def get_fp8_recipe_from_name_string(name: str): +def get_quantization_recipe_from_name_string(name: str): """Query recipe from a given name string""" match name: case "DelayedScaling": @@ -107,5 +114,7 @@ def get_fp8_recipe_from_name_string(name: str): return recipe.MXFP8BlockScaling() case "Float8CurrentScaling": return recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return recipe.NVFP4BlockScaling() case _: - raise ValueError(f"Invalid fp8_recipe, got {name}") + raise ValueError(f"Invalid quantization_recipe, got {name}") diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a979e1775..fa7102cb42 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -10,9 +10,11 @@ TEST_CASES=( "test_te_delayed_scaling_fp8" "test_te_current_scaling_fp8" "test_te_mxfp8" +"test_te_nvfp4" "test_te_bf16_shardy" "test_te_delayed_scaling_fp8_shardy" "test_te_current_scaling_fp8_shardy" +"test_te_nvfp4_shardy" ) : ${TE_PATH:=/opt/transformerengine} diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 41832650fa..5fc7efbba7 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -21,13 +21,13 @@ from common import ( is_bf16_supported, - get_fp8_recipe_from_name_string, + get_quantization_recipe_from_name_string, assert_params_sufficiently_sharded, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DEVICE_DP_AXIS = "data" @@ -36,6 +36,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY 
= "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -121,6 +122,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -135,11 +138,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -150,7 +153,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -159,11 +162,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = 
test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -223,7 +228,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -257,7 +262,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -275,7 +280,8 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] @@ -355,7 +361,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -367,22 +380,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, 
dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -402,16 +417,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -466,8 +481,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -477,7 +493,7 
@@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -485,7 +501,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -493,14 +509,22 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -509,7 +533,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -518,14 +542,23 @@ def 
test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -534,7 +567,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -544,24 +577,27 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
- ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_with_sp_shardy(self): """Test Transformer Engine with MXFP8 + SP""" self.args.enable_shardy = True @@ -569,7 +605,17 @@ def test_te_mxfp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index bc6a567521..68fb3ddd3f 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -19,17 +19,18 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding -from common import is_bf16_supported, 
get_fp8_recipe_from_name_string +from common import is_bf16_supported, get_quantization_recipe_from_name_string import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -97,6 +98,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -111,11 +114,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -126,7 +129,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -135,11 +138,13 @@ def 
eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -199,7 +204,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -254,7 +259,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -270,6 +275,7 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) + rng, sr_rng = jax.random.split(rng) init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} input_shape = [args.batch_size, args.max_seq_len] @@ -322,7 +328,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + 
masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -334,22 +347,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -369,16 +384,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for training (default: 256)", + help="input batch size for training (default: 512)", ) parser.add_argument( "--test-batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for testing (default: 256)", + help="input batch size for testing (default: 512)", ) parser.add_argument( "--max-seq-len", @@ -430,8 +445,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = 
is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -441,7 +457,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -449,7 +465,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -457,7 +473,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -465,6 +481,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @@ -472,7 +496,7 @@ def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" 
self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -481,7 +505,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -490,18 +514,24 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
- ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index abf6a407b6..358fbca4bb 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -25,7 +25,8 @@ is_bf16_supported, is_fp8_supported, is_mxfp8_supported, - get_fp8_recipe_from_name_string, + is_nvfp4_supported, + get_quantization_recipe_from_name_string, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex @@ -39,6 +40,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -175,6 +177,8 @@ def train_epoch( epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_input = sentence[perm, ...] batch_mask = mask[perm, ...] batch_label = label[perm, ...] 
@@ -200,11 +204,11 @@ def train_epoch( return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels, 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -216,7 +220,16 @@ def loss_fn(var_collect, disable_dropout=False): def eval_model( - state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec + state, + test_ds, + batch_size, + var_collect, + eval_fn, + mesh, + inputs_pspec, + masks_pspec, + labels_pspec, + rngs, ): """Evaluation loop.""" global_input_shape, input_named_sharding, sentence = shard_array_wrapper( @@ -233,7 +246,8 @@ def eval_model( all_accuracy = [] for batch_input, batch_mask, batch_label in zip(sentence, mask, label): - + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} shard_input = jax.make_array_from_single_device_arrays( global_input_shape, input_named_sharding, [batch_input] ) @@ -244,7 +258,7 @@ def eval_model( global_label_shape, label_named_sharding, [batch_label] ) - loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect) + loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -303,7 +317,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
- rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -372,7 +386,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -390,7 +404,8 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] @@ -444,7 +459,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -456,14 +478,16 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, 
DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, @@ -488,6 +512,7 @@ def train_and_evaluate(args): inputs_pspec, masks_pspec, labels_sharding.spec, + rngs, ) if args.process_id == 0: print( @@ -508,16 +533,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -629,7 +654,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -639,6 +664,14 @@ def test_te_mxfp8(self): result = self.exec(True, "MXFP8BlockScaling") assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling") + assert result[0] < 0.451 and result[1] > 0.79 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" @@ -659,19 +692,24 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = 
self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" ) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) + assert result[0] < 0.451 and result[1] > 0.79 + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 826d0d2fc7..320483099d 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -16,14 +16,15 @@ from flax import linen as nn from flax.training import train_state -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import is_bf16_supported, get_quantization_recipe_from_name_string import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode PARAMS_KEY = "params" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -92,6 +93,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key 
in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -107,11 +110,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): @jax.jit -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -122,7 +125,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect): +def eval_model(state, test_ds, batch_size, var_collect, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -131,11 +134,15 @@ def eval_model(state, test_ds, batch_size, var_collect): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_step( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) all_loss.append(loss) all_accuracy.append(accuracy) @@ -195,7 +202,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
- rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -208,14 +215,15 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -238,21 +246,25 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect ) - test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, rngs + ) print( f"Epoch: {epoch:>2} " @@ -329,8 +341,9 @@ def encoder_parser(args): 
class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 3 epochs for testing""" @@ -340,7 +353,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.452 and actual[1] > 0.788 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -348,7 +361,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -356,7 +369,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.461 and actual[1] > 0.784 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -364,7 +377,15 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + 
self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.476 and actual[1] > 0.775 if __name__ == "__main__": diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 92baf4b0c5..81bea4a324 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -18,11 +18,11 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DIR = str(Path(__file__).resolve().parents[1]) sys.path.append(str(DIR)) -from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string +from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string IMAGE_H = 28 IMAGE_W = 28 @@ -189,7 +189,7 @@ def train_and_evaluate(args): label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -308,8 +308,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 124e0248b6..2934e48df1 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,11 +40,13 @@ QuantizerFactory, QuantizeLayout, noop_quantizer_set, + should_use_rht, ) from 
transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.common import recipe GEMM_CASES = [ (256, 256, 512), @@ -56,16 +58,23 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] -is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available() -is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) -supported_scaling_modes = [] +# TODO(Phuong): remove unneccessary pytest skips +is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.DELAYED_TENSOR_SCALING +) +is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.MXFP8_1D_SCALING +) +is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.NVFP4_1D_SCALING +) + """ Find supported scaling modes""" -if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) - supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) -if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) +supported_scaling_modes = helper.get_supported_scaling_modes() +non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling] +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] def is_shape_supported_by_mxfp8(input_shape): @@ -83,12 +92,13 @@ def assert_bitwise_scaled_tensors( a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True ): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): - if not precise_comparison: + if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling: 
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) return assert a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype + assert a.data_layout == b.data_layout if a.scaling_mode.is_tensor_scaling(): # Assert in dq_dtype as some unfused codepaths have an intermediate cast # to an input dtype which reduces precision compared to everything in fp32 @@ -96,6 +106,16 @@ def assert_bitwise_scaled_tensors( elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Compare MXFP8 scales as uint8 assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8)) + elif a.scaling_mode.is_nvfp4_scaling: + assert_allclose(a.amax, b.amax) + assert_allclose(a.scale_inv, b.scale_inv) + if not precise_comparison: + mismatch = a.data != b.data + mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32)) + assert ( + mismatch_fraction < 0.05 + ), f"Mismatch fraction {mismatch_fraction} is too high" + return else: raise ValueError(f"Unsupported scaling mode {a.scaling_mode}") assert_allclose(a.data, b.data) @@ -603,10 +623,24 @@ def test_norm_forward_with_block_scaling_fp8( ) -QUANTIZE_OUTPUT_DTYPES = { +QUANTIZE_OUTPUT_FP8_DTYPES = { "L0": [jnp.float8_e4m3fn], "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2], } +QUANTIZE_OUTPUT_DTYPES = { + test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} +QUANTIZE_QDTYPE_AND_SCALING_MODES = { + test_level: [ + (q_dtype, scaling_mode) + for q_dtype, scaling_mode in zip( + QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes + ) + if q_dtype in scaling_mode.get_compatible_q_dtypes() + ] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [ ((32, 64), -1), @@ -615,8 +649,7 @@ def test_norm_forward_with_block_scaling_fp8( ((32, 256, 128), -1), ((32, 256, 128), -2), ((64, 32, 32, 256), -1), - ((64, 32, 32, 256), -2), - ((64, 32, 32, 256), -3), + ((8192, 2, 4096), -2), ] 
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = { @@ -636,18 +669,38 @@ def test_norm_forward_with_block_scaling_fp8( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn]) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper( - "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + "q_layout", + [ + QuantizeLayout.ROWWISE, + QuantizeLayout.COLWISE, + QuantizeLayout.ROWWISE_COLWISE, + ], ) class TestQuantize: """ Purely quantization related tests that will always test on a wider set of types and shapes """ + def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + """Temporary hack to skip unsupported FP4 cases until we implement them""" + if q_dtype not in scaling_mode.get_compatible_q_dtypes(): + pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}") + return + + # HACK: FIXME TODO(jberchtold) + row = reduce(operator.mul, input_shape[flatten_axis:], 1) + col = reduce(operator.mul, input_shape[:flatten_axis], 1) + will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) + if will_use_rht and (row % 64 != 0 or col % 128 != 0): + pytest.skip("Unfused RHT is not supported currently, skipping") + def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + key = jax.random.PRNGKey(0) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. 
delayed scaling) @@ -657,6 +710,68 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt q_layout=q_layout, ) + if scaling_mode.is_nvfp4_scaling: + if in_dtype != jnp.bfloat16: + pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently") + return + q_func = _jax_quantize + # For NVFP4 scaling, the maximum possible error for a single value can be high between the dequantized and original tensors. To ensure quantization and dequantization is operating correctly without requiring a very high tolerance for all values, we instead test that quantizing the dequantized tensor is bitwise identical to the original quantized tensor. + x = jax.random.uniform(key, input_shape, in_dtype) * 10 + q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis) + + dq_rowwise = None + dq_colwise = None + if isinstance(q1, ScaledTensor1x): + dq = q1.dequantize() + if q1.is_colwise: + dq_colwise = dq + else: + dq_rowwise = dq + elif isinstance(q1, ScaledTensor2x): + dq_rowwise = q1.rowwise_tensor.dequantize() + dq_colwise = q1.colwise_tensor.dequantize() + else: + raise ValueError(f"Unsupported output type {type(q1)}") + + # We only compare Q-DQ for the same quantization layout. If we for example QDQ rowwise, then re-quantize colwise, the error will be larger and may not be bitwise identical to the original colwise quantization. 
+ if dq_rowwise is not None: + assert ( + dq_rowwise.shape == x.shape + ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}" + q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_rowwise = ( + q2_rowwise + if isinstance(q2_rowwise, ScaledTensor1x) + else q2_rowwise.rowwise_tensor + ) + q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor + assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise) + + if dq_colwise is not None: + # Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape + flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis + colwise_flatten_axis = len(input_shape) - flatten_axis + dq_colwise = jnp.transpose( + dq_colwise, + (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)), + ) + assert ( + dq_colwise.shape == x.shape + ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}" + q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_colwise = ( + q2_colwise + if isinstance(q2_colwise, ScaledTensor1x) + else q2_colwise.colwise_tensor + ) + q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor + assert_bitwise_scaled_tensors(q1_colwise, q2_colwise) + + assert ( + dq_rowwise is not None or dq_colwise is not None + ), "At least one of rowwise or colwise dq must be not None" + return + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -664,9 +779,33 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis) assert_dequantized_scaled_tensor(scaled_tensor, x) + def _should_use_precise_comparison( + self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ): + # TODO(jberchtold): Remove this hack once we have a 
better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values. + RHT_SLIGHT_MISMATCH_SHAPES = [ + ((32, 256, 128), -1), + ((64, 32, 32, 256), -1), + ((8192, 2, 4096), -2), + ] + + if ( + should_use_rht(scaling_mode, q_layout=q_layout) + and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES + ): + # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes + return False + + if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16: + # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation + return False + + return True + def test_quantize_bitwise( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -677,15 +816,202 @@ def test_quantize_bitwise( jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) - assert_bitwise_scaled_tensors(te_output, jax_output) + try: + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + except AssertionError as e: + if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: + error_message = e.args[0] + if "RHT requires input to be bfloat16" in error_message: + # Successfully caught the expected error, early return from the test + return + raise e + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ), + ) + + def 
test_quantize_bitwise_jitted( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis + ): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + + te_quantizer, jax_quantizer = QuantizerFactory.create( + n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout + ) + + jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) + te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) + + jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) + + try: + te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + except AssertionError as e: + if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: + error_message = e.args[0] + if "RHT requires input to be bfloat16" in error_message: + # Successfully caught the expected error, early return from the test + return + raise e + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ), + ) + + +@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) +@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) +@pytest_parametrize_wrapper( + "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling] +) +@pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] +) +class TestStochasticRounding: + + def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]: + """Dequantizes a ScaledTensor back to it's original jnp.ndarray form. 
This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor.""" + if isinstance(scaled_tensor, ScaledTensor1x): + dq = scaled_tensor.dequantize() + if scaled_tensor.data_layout == "T": + dq = jnp.transpose( + dq, + ( + *range(scaled_tensor.flatten_axis, dq.ndim), + *range(scaled_tensor.flatten_axis), + ), + ) + return [dq] + elif isinstance(scaled_tensor, ScaledTensor2x): + [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor) + [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor) + return [rowwise_dq, colwise_dq] + raise ValueError( + "Unsupported ScaledTensor type, expected ScaledTensor but received" + f" {type(scaled_tensor)}" + ) + + def _sample_sr_qdq( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> list[jnp.ndarray]: + """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors.""" + dq_tensors = [] + + key = jax.random.PRNGKey(0) + + for i in range(num_samples): + iter_key = jax.random.fold_in(key, i) + sr_rng_state = jax.random.randint( + iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 + ) + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=sr_rng_state, + ) + + q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) + iter_dq = self._dequantize(q_output) + dq_tensors.extend(iter_dq) + + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + + dq_var = jnp.var(jnp.stack(dq_tensors)) + assert ( + dq_var > 0 + ), "Variance of dequantized tensors is zero, stochastic rounding may not be working" + + return dq_tensors + + def _round_nearest( + 
self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> jnp.ndarray: + """Quantizes and dequantizes the input tensor with round nearest quantization.""" + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=None, + ) + q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) + rn_dq = self._dequantize(q_output)[0] + return rn_dq + + def _test_sr( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> float: + """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples.""" + dq_tensors = self._sample_sr_qdq( + num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + round_nearest_tensor = self._round_nearest( + q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs)) + + assert sr_mae < rn_mae, ( + f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than" + f" round nearest ({rn_mae})" + ) + + return sr_mae + + def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. 
Asserts that the MAE of both implementations is close to each other.""" + # HACK: FIXME TODO(jberchtold) + row = reduce(operator.mul, input_shape[flatten_axis:], 1) + col = reduce(operator.mul, input_shape[:flatten_axis], 1) + will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) + if will_use_rht and (row % 64 != 0 or col % 128 != 0): + pytest.skip("Unfused RHT is not supported currently, skipping") + + key = jax.random.PRNGKey(0) + inputs = jax.random.uniform(key, input_shape, in_dtype) + + NUM_SAMPLES = 10 + + te_mean_error = self._test_sr( + NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + jax_mean_error = self._test_sr( + NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) -@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @pytest_parametrize_wrapper("with_group_sizes", [True, False]) @pytest_parametrize_wrapper( @@ -724,7 +1050,6 @@ def test_grouped_qdq( q_layout=q_layout, n_groups=n_groups, ) - scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -736,9 +1061,8 @@ def test_grouped_qdq( class TestFusedQuantize: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + 
@pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) @@ -860,7 +1184,7 @@ def test_quantize_dact_dbias_no_quantization( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -886,7 +1210,7 @@ def test_quantize_dact_dbias_tensor_scaling( @pytest_parametrize_wrapper( "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] ) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -919,6 +1243,11 @@ def test_quantize_dact_dbias_mxfp8_scaling( (jnp.float8_e4m3fn, jnp.float8_e5m2), ] +supported_nvfp4_scaling_mode_pairs = [ + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING), + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING), +] + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): @@ -960,7 +1289,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) 
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): @@ -994,6 +1323,40 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + # TODO(Phuong): add bitwise test + @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + @pytest_parametrize_wrapper("with_jax_gemm", [True, False]) + def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm): + x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T" + w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N" + if x_uses_rht != w_uses_rht: + # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT + pytest.skip("RHT must be used for both or neither operand, skipping") + + lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + lhs_quantizer = QuantizerFactory.create( + scaling_mode=lhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + rhs_quantizer = QuantizerFactory.create( + scaling_mode=rhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + ) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def 
test_dense_grad_bf16(self, m, n, k): data_layout = "NN" @@ -1019,11 +1382,10 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1044,14 +1406,9 @@ def ref_func(x, w, bias, data_layout): value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, - ) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( @@ -1062,10 +1419,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2) + 
assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.fixture(name="random_inputs") @@ -1087,11 +1444,11 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): + def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ @@ -1108,12 +1465,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, - ) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) if norm_type == "layernorm": beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) @@ -1148,7 +1500,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1158,22 +1510,22 @@ def ref_func(x, w, 
gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm + self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule @@ -1201,10 +1553,7 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, + fp8_recipe=recipe, ) if norm_type == "layernorm": @@ -1251,7 +1600,7 @@ def ref_func(x, gamma, 
kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1272,18 +1621,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) - if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) + fwd_dtype = quantizer_sets[0].x.q_dtype + bwd_dtype = quantizer_sets[0].dgrad.q_dtype + assert_allclose(prim_out, ref_out, dtype=fwd_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype) # E5M2 * E5M2 is not supported @@ -1388,7 +1735,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + 
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): fwd_dtype, bwd_dtype = fwd_bwd_dtype @@ -1469,7 +1816,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): "fwd_bwd_dtype", [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], ) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d38f43d002..bf78ed3bb4 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,6 +1,7 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
+import re from typing import Callable, Sequence, Union, Optional import pytest @@ -17,7 +18,11 @@ ) from transformer_engine.common import recipe -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import ( + is_fp8_available, + ScalingMode, + get_quantize_config_with_recipe, +) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -33,19 +38,20 @@ W_JOINED_AXES, ) from transformer_engine.jax.sharding import MeshResource -from transformer_engine.jax.quantize import QuantizerFactory +from transformer_engine.jax.quantize import ( + QuantizerFactory, + get_supported_quantization_recipes, + is_scaling_mode_supported, +) from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) -SUPPORTED_RECIPES = [] -if is_fp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) - SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) -if is_mxfp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES = get_supported_quantization_recipes() +SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES] DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] @@ -141,6 +147,7 @@ def layernorm_fp8_mlp_prim_func( layernorm_type: str = "rmsnorm", activation_type: Sequence[Union[str, Callable]] = 
("gelu",), multi_gpus: bool = False, + quantization_recipe: recipe.Recipe = None, ) -> jnp.ndarray: if multi_gpus: @@ -154,7 +161,9 @@ def layernorm_fp8_mlp_prim_func( dot_1_input_axes = dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None - quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) + quantizer_sets = QuantizerFactory.create_set( + n_quantizer_sets=2, fp8_recipe=quantization_recipe + ) # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 return jnp.mean( @@ -182,7 +191,7 @@ def _test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -202,7 +211,9 @@ def _test_layernorm_mlp_grad( # Single GPU with fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + mesh_resource=MeshResource(), ): single_jitter = jax.jit( value_and_grad_func, @@ -214,7 +225,9 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + mesh_resource=mesh_resource, ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -254,10 +267,16 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn - bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + fwd_test_type = bwd_test_type = dtype + if quantization_recipe is not None: + quantize_config = get_quantize_config_with_recipe(quantization_recipe) + fwd_test_type = quantize_config.FWD_DTYPE + bwd_test_type = quantize_config.BWD_DTYPE - 
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: @@ -278,13 +297,12 @@ def _test_layernorm_mlp_grad( err_msg=f"multi_grads[{i}] is not close", ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -293,27 +311,28 @@ def test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + 
SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, @@ -322,18 +341,18 @@ def test_layernorm_mlp_grad_shardy( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -346,7 +365,7 @@ def _test_layernorm_mlp( input_shape, dtype, use_fp8, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -355,14 +374,16 @@ def _test_layernorm_mlp( layernorm_type = "rmsnorm" rng = jax.random.PRNGKey(0) - subkeys = jax.random.split(rng, 2) + subkeys = jax.random.split(rng, 3) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) - init_rngs = {"params": subkeys[1]} + init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]} with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with fp8_autocast( + enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource() + ): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, intermediate_dim=INTERMEDIATE, @@ -371,7 +392,7 @@ def _test_layernorm_mlp( ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True + params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Multi GPUs @@ -379,7 +400,7 @@ def _test_layernorm_mlp( devices = 
np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, @@ -399,7 +420,7 @@ def _test_layernorm_mlp( ) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True + params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Make sure params values are the same @@ -411,8 +432,8 @@ def _test_layernorm_mlp( rtol = None l40_tolerance_update = ( get_min_device_compute_capability() == 89 - and fp8_recipe == recipe.DelayedScaling() and use_fp8 + and quantization_recipe.delayed() and dtype == jnp.float16 and activation_type == ("gelu",) ) @@ -430,8 +451,8 @@ def _test_layernorm_mlp( # within tolerance to the float32 ground truth. 
jax_triton_gemm_precision_tolerance_update = ( with_jax_gemm - and fp8_recipe is not None - and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) + and quantization_recipe is not None + and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling()) and dtype in (jnp.bfloat16, jnp.float16) and activation_type == ("gelu", "linear"), ) @@ -457,22 +478,30 @@ def test_layernorm_mlp_layer( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -480,7 +509,7 @@ def test_layernorm_mlp_layer_fp8( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -501,24 +530,30 @@ def test_layernorm_mlp_layer_shardy( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=True, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not 
is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -526,7 +561,7 @@ def test_layernorm_mlp_layer_fp8_shardy( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e4511e1fe0..e9f71a32fb 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -10,20 +10,27 @@ import numpy as np from utils import assert_allclose -from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling +from transformer_engine.common.recipe import ( + DelayedScaling, + MXFP8BlockScaling, + Float8CurrentScaling, + NVFP4BlockScaling, +) from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.jax import fp8_autocast, get_delayed_scaling +from 
transformer_engine.jax import fp8_autocast from transformer_engine.jax.quantize import ( get_quantize_config, - is_fp8_available, + is_scaling_mode_supported, ScalingMode, update_collections, TensorSource, ) +from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.sharding import MeshResource, global_mesh_resource -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) class TestHelper(unittest.TestCase): @@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase): def _check_default_state(self): self.assertFalse(get_quantize_config().is_fp8_enabled()) - def _compare_delay_scaling(self, ref, test): - self.assertTrue(ref.margin == test.margin) - self.assertTrue(ref.fp8_format == test.fp8_format) - self.assertTrue(ref.amax_history_len == test.amax_history_len) - self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) + def _compare_delay_scaling(self, test): + self.assertEqual(get_quantize_config().MARGIN, test.margin) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) + self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len) + self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) for tensor_source in TensorSource: 
self.assertEqual( get_quantize_config().get_scaling_mode(tensor_source), @@ -67,13 +76,26 @@ def _compare_current_scaling(self, test): ) def _compare_mxfp8_scaling(self, test): - self.assertEqual(get_quantize_config().MARGIN, test.margin) - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) for tensor_source in TensorSource: self.assertEqual( get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING ) + def _compare_nvfp4_scaling(self, test): + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1]) + for tensor_source in TensorSource: + target_scaling_mode = ( + ScalingMode.NVFP4_2D_SCALING + if tensor_source == TensorSource.KERNEL + else ScalingMode.NVFP4_1D_SCALING + ) + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode + ) + @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_delayed_scaling(self): self._check_default_state() @@ -86,14 +108,14 @@ def test_fp8_autocast_delayed_scaling(self): ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) + self._compare_delay_scaling(ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) + self._compare_delay_scaling(ds) self._check_default_state() @@ -133,16 +155,27 
@@ def test_fp8_autocast_mxfp8_block_scaling(self): self._check_default_state() - bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) + bs = MXFP8BlockScaling() with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() - bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) + @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) + def test_fp8_autocast_nvfp4_block_scaling(self): + self._check_default_state() + + with fp8_autocast( + enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource() + ): + self._check_default_state() + + self._check_default_state() + + bs = NVFP4BlockScaling() with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_mxfp8_scaling(bs) + self._compare_nvfp4_scaling(bs) self._check_default_state() diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 8ad6dccfec..c28e68a15f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1544,6 +1544,12 @@ def dtype_tols( rtol = eps_relaxed if atol is None: atol = max(ulp, eps_relaxed) + + # Manually set tols for nvfp4 + if dtype == jnp.float4_e2m1fn: + atol = 0.05 + rtol = 0.1 + return {"rtol": rtol, "atol": atol} diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 0b5e43402f..354a1293e6 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,7 +34,7 @@ from . import flax from . 
import quantize -from .quantize import fp8_autocast, update_collections, get_delayed_scaling +from .quantize import fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource @@ -47,7 +47,6 @@ "NVTE_FP8_COLLECTION_NAME", "fp8_autocast", "update_collections", - "get_delayed_scaling", "MeshResource", "flax", "quantize", diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index ef8d76cd05..c0285e157a 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for c++ extensions""" from .activation import * +from .amax import * from .attention import * from .normalization import * from .quantization import * diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index be1f9f9564..bb3c56bcf1 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1314,7 +1314,10 @@ def act_lu( ) return out - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( x=x, @@ -1488,7 +1491,10 @@ def quantize_dact_dbias( if war_output is not None: return war_output - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
out = dact_lu( dz=dz, diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py new file mode 100644 index 0000000000..2f3bc402ec --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -0,0 +1,420 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for amax calculation""" +from enum import Enum + + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule +from jax.sharding import PartitionSpec + +from .base import BasePrimitive, register_primitive +from .misc import ( + get_padded_spec, + NamedSharding, +) +from ..sharding import ( + global_mesh_resource, + lax_paral_op, +) +from ..quantize import ( + get_wgrad_sign_vector, + get_sign_from_vector, +) + + +__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"] + + +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): + """Reduce the amax based on its scope""" + gmesh = global_mesh_resource() + sequence_dim = 0 if transpose_batch_sequence else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: + return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + # Run AR across FSDP + if self is AmaxScope.FSDP: + return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + return amax + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = ( + 1, + 2, + ) # amax_scope, transpose_batch_sequence + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, 
+ *, + amax_scope, + transpose_batch_sequence, + ): + """ + amax calcuation abstract + """ + del amax_scope, transpose_batch_sequence + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + transpose_batch_sequence, + ): + """ + amax calcuation implementation + """ + del amax_scope, transpose_batch_sequence + amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) + return amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + transpose_batch_sequence, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding + + @staticmethod + def partition( + amax_scope, + transpose_batch_sequence, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculation.amax_sharding", + ) + + def sharded_impl(x): + amax = AmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh + ) + + return amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, amax_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types): + """ + amax calcuation shardy_sharding_rule + """ + del amax_scope, transpose_batch_sequence, mesh, result_types + prefix = "AmaxCal" + input_spec 
= tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_spec = (f"{prefix}_amax",) + return SdyShardingRule((input_spec,), (output_spec,)) + + +register_primitive(AmaxCalculationPrimitive, outer_only=True) + + +class RHTAmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning for calculating regular and post-Random Hadamard Transform (RHT) amax using TE's fused kernels. + """ + + name = "te_rht_amax_ffi" + multiple_results = True + impl_static_args = ( + 1, # amax_scope + 2, # transpose_batch_sequence + 3, # rht_matrix_random_sign_mask_t + 4, # produce_regular_amax + 5, # flatten_axis + ) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + amax calcuation abstract + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ) + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}" + + amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + return amax_aval, post_rht_amax_aval + + @staticmethod + def lowering( + ctx, + x, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + te_dbias_quantize_p lowering rules + """ + del amax_scope, transpose_batch_sequence + (x_aval,) = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape) + assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!" 
+ + return ffi.ffi_lowering( + RHTAmaxCalculationPrimitive.name, + )( + ctx, + x, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + + @staticmethod + def impl( + x, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + amax calcuation implementation + """ + assert RHTAmaxCalculationPrimitive.inner_primitive is not None + ( + amax, + post_rht_amax, + ) = RHTAmaxCalculationPrimitive.inner_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + return amax, post_rht_amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + arg_infos, + result_infos, + ) # Unused. 
+ amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="RHTAmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding, amax_sharding + + @staticmethod + def partition( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="RHTAmaxCalculationPrimitive.amax_sharding", + ) + out_shardings = (amax_sharding, amax_sharding) + + def sharded_impl(x): + amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh + ) + post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + post_rht_amax, x_spec, transpose_batch_sequence, mesh + ) + + return amax, post_rht_amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + value_types, + result_types, + ): + """ + amax calcuation shardy_sharding_rule + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + result_types, + ) + prefix = "RHTAmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_amax_spec = (f"{prefix}_amax",) + output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",) + return SdyShardingRule((input_spec,), (output_amax_spec, 
output_post_rht_amax_spec)) + + +register_primitive(RHTAmaxCalculationPrimitive) + + +def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool): + """ + Compute the maximum absolute value (amax) of the input tensor. + """ + assert AmaxCalculationPrimitive.outer_primitive is not None + return AmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + + +def calculate_post_rht_amax( + x: jnp.ndarray, + amax_scope: AmaxScope, + transpose_batch_sequence: bool, + produce_regular_amax: bool, + flatten_axis: int, +): + """Compute the post-Random Hadamard Transform (RHT) amax of the input tensor, and optionally the regular amax. + + Args: + x: Input tensor. + amax_scope: The scope for amax reduction (local, TPSP, or FSDP). + transpose_batch_sequence: Whether the input tensor has its batch and sequence dimensions transposed. + produce_regular_amax: Whether to compute and return the regular amax alongside the post-RHT amax. + flatten_axis: The axis at which to flatten the input tensor before applying RHT. + Returns: + A tuple containing: + - The regular amax if `produce_regular_amax` is True, otherwise None. + - The post-RHT amax. 
+ """ + amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()), + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + + if produce_regular_amax: + return amax, post_rht_amax + return None, post_rht_amax diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7fe433bcc6..b72161f1aa 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -32,6 +32,7 @@ AbstractBaseTensor, NoScaleTensor, ScaledTensor, + ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, @@ -43,6 +44,7 @@ noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, + should_use_rht, ) from .misc import get_padded_spec, is_all_reduce_in_float32 from ..sharding import ( @@ -138,6 +140,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_lhs_colwise = lhs_is_transposed and ( lhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or lhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( @@ -153,6 +156,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_rhs_colwise = not rhs_is_transposed and ( rhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or rhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( @@ -165,9 +169,27 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) + 
def uses_rht(q: AbstractBaseTensor) -> bool: + return isinstance(q, ScaledTensor1x) and should_use_rht( + q.scaling_mode, is_colwise=q.is_colwise + ) + + # TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class + assert uses_rht(lhs_q) == uses_rht(rhs_q), ( + "With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" + " quantized as well. This is to ensure the RHT is applied to both and will cancel out in" + " the GEMM." + ) + return lhs_q, rhs_q +def _get_nvfp4_tensor_scale_inv(amax): + DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32) + return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) + + def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -345,7 +367,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -357,6 +379,8 @@ def abstract( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -404,7 +428,9 @@ def _dims_are_consecutive(dims): lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) if scaling_mode != ScalingMode.NO_SCALING: - assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes( + lhs.dtype, rhs.dtype + ), ( "cuBLAS GEMM quantized operands have incompatible data types: " f"{lhs.dtype} x {rhs.dtype}." ) @@ -484,6 +510,8 @@ def _dims_are_consecutive(dims): f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." 
) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + assert alpha.size == 1 and alpha.dtype == jnp.float32 + assert beta.size == 1 and beta.dtype == jnp.float32 # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() @@ -510,6 +538,8 @@ def lowering( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -530,7 +560,7 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), @@ -563,6 +593,8 @@ def impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -626,6 +658,8 @@ def impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, @@ -675,6 +709,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -694,6 +730,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -1001,6 +1039,9 @@ def partition( gelu_input_specs = (None,) arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + # Alpha, beta + arg_shardings += (none_sharding, none_sharding) + # Assemble output shardings out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] @@ -1014,7 +1055,7 @@ def partition( pre_gelu_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) - def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta): # We should not fuse bias in the output reduction case sharded_fuse_bias = 
fuse_bias and reduce_spec is None outputs = GemmPrimitive.impl( @@ -1024,6 +1065,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, @@ -1114,8 +1157,10 @@ def _generate_operand_rules(name, ndim, cdims): rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) bias_spec = rhs_non_cspec if fuse_bias else ("…4",) - dbias_spec = bias_spec if grad else ("…5") - gelu_spec = out_spec if fuse_gelu else ("…6",) + gelu_spec = out_spec if fuse_gelu else ("…5",) + alpha_spec = ("_6",) + beta_spec = ("_7",) + dbias_spec = bias_spec if grad else ("…8") return SdyShardingRule( operand_mappings=( @@ -1125,6 +1170,8 @@ def _generate_operand_rules(name, ndim, cdims): rhs_scale_specs, bias_spec, gelu_spec, + alpha_spec, + beta_spec, ), result_mappings=( out_spec, @@ -1178,6 +1225,7 @@ def _te_gemm( # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + lhs_amax = rhs_amax = None # Extract GEMM custom op inputs from quantized operands if isinstance(lhs_q, ScaledTensor): assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( @@ -1192,6 +1240,7 @@ def _te_gemm( lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_amax = lhs_q.amax if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -1201,7 +1250,11 @@ def _te_gemm( if isinstance(rhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() - assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + assert ( + rhs_q.scaling_mode == lhs_q.scaling_mode + 
or rhs_q.scaling_mode.is_nvfp4_scaling + and lhs_q.scaling_mode.is_nvfp4_scaling + ), ( "cuBLAS GEMM quantized operands have mismatched scaling types, " f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." ) @@ -1209,6 +1262,15 @@ def _te_gemm( rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_amax = rhs_q.amax + + alpha = jnp.ones((1,), jnp.float32) + beta = jnp.zeros((1,), jnp.float32) + if scaling_mode.is_nvfp4_scaling: + assert lhs_amax is not None and rhs_amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -1224,6 +1286,8 @@ def _te_gemm( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, @@ -1514,15 +1578,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): @partial(jax.jit, static_argnums=(2,)) -def _jax_gemm_mxfp8_1d( +def _jax_scaled_matmul( lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] ): """ JAX GEMM for MXFP8 via scaled_matmul """ - assert ( - rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING - ), "rhs does not have MXFP8 1D scaling mode" + assert rhs.scaling_mode in ( + ScalingMode.MXFP8_1D_SCALING, + ScalingMode.NVFP4_1D_SCALING, + ScalingMode.NVFP4_2D_SCALING, + ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -1537,21 +1603,48 @@ def _jax_gemm_mxfp8_1d( f" {rhs.is_colwise}" ) + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: + out_dtype = lhs.dq_dtype + assert ( + lhs.data_layout == "N" and rhs.data_layout == "N" + ), f"Got 
lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + else: + if lhs.data_layout == "T": + lhs_contract = transpose_dims( + lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis + ) + if rhs.data_layout == "T": + rhs_contract = transpose_dims( + rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis + ) + out_dtype = jnp.float32 + # Reshape + Transpose (if needed) # [..., M, K] -> [1, reduce(..., M), K] # [..., K, M] -> [1, reduce(..., M), K] - lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch)) - rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch)) - lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) - rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) + lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch), lhs.data_layout == "T") + rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch), rhs.data_layout == "T") + lhs_scale_3d = _shape_normalization( + lhs.scale_inv, (lhs_contract, lhs_batch), lhs.data_layout == "T" + ) + rhs_scale_3d = _shape_normalization( + rhs.scale_inv, (rhs_contract, rhs_batch), rhs.data_layout == "T" + ) # JAX scaled_matmul only supports NT now (TN-gemm) # * Expected shape: # * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block) out_3d = jax.nn.scaled_matmul( - lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype + lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) + if lhs.scaling_mode.is_nvfp4_scaling: + assert lhs.amax is not None and rhs.amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + out_3d = (out_3d * alpha).astype(lhs.dq_dtype) + # Reshape [1, reduce(..., M), N] -> [..., M, N] lhs_remain_shape = tuple( lhs.data.shape[dim] for dim in 
range(len(lhs.data.shape)) if dim not in lhs_contract @@ -1560,6 +1653,7 @@ def _jax_gemm_mxfp8_1d( rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract ) out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) + return out @@ -1575,7 +1669,7 @@ def _jax_gemm( """ dim_nums = (contracting_dims, ((), ())) - def _jax_gemm_fp8_impl(lhs, rhs): + def _jax_gemm_impl(lhs, rhs): if lhs.scaling_mode.is_tensor_scaling(): assert ( rhs.scaling_mode == lhs.scaling_mode @@ -1587,15 +1681,15 @@ def _jax_gemm_fp8_impl(lhs, rhs): ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) - if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) + if lhs.scaling_mode.is_1d_block_scaling: + return _jax_scaled_matmul(lhs, rhs, dim_nums) raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + return _jax_gemm_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 52f5edbf3a..572d82f18d 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -6,8 +6,6 @@ import os import functools from typing import Tuple -from importlib.metadata import version as get_pkg_version -from packaging.version import Version as PkgVersion import numpy as np @@ -75,7 +73,8 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.int64.dtype: TEDType.kInt64, jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, - jnp.uint8.dtype: TEDType.kByte, + jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0, + jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1, } if jax_dtype not in converter: @@ -151,16 +150,6 @@ def 
get_cudnn_version() -> Tuple[int, int, int]: return (major, minor, patch) -@functools.lru_cache(maxsize=None) -def jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def get_xla_flag(flag: str, default=None, cast=str): """ Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3ce8a19a76..90ab5fb7fe 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -28,7 +28,10 @@ get_cudnn_version, ) from .quantization import _quantize_dbias_impl, AmaxScope -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp_tpsp, +) from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, @@ -1031,7 +1034,10 @@ def layernorm_fwd( ) return out, mu, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. out, mu, rsigma = layernorm_fwd( x=x, @@ -1276,7 +1282,10 @@ def rmsnorm_fwd( ) return out, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. 
out, rsigma = rmsnorm_fwd( x=x, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 38fd50a00f..b3f1e60f9a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,7 +6,6 @@ from functools import reduce from typing import Tuple, Optional, Union import math -from enum import Enum import jax @@ -17,6 +16,7 @@ import transformer_engine_jax +from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, @@ -31,8 +31,7 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, - global_mesh_resource, - lax_paral_op, + num_of_devices, ) from ..quantize import ( ScaledTensor2x, @@ -45,6 +44,8 @@ ScalingMode, compute_scale_from_amax, NoScaleTensor, + get_rht_matrix, + should_use_rht, ) @@ -59,14 +60,16 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 3, - 4, - 5, - 6, - 7, - 8, - 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer + 6, # out_dtype + 7, # scaling_mode + 8, # q_layout + 9, # flatten_axis + 10, # scale_dtype + 11, # is_dbias + 12, # is_outer + 13, # stochastic_rounding + 14, # use_rht + ) inner_primitive = None outer_primitive = None @@ -75,6 +78,9 @@ def abstract( x_aval, scale_aval, amax_aval, + sr_rng_state_aval, + post_rht_amax_aval, + rht_matrix_aval, *, out_dtype, scaling_mode, @@ -83,6 +89,8 @@ def abstract( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p abstract @@ -91,6 +99,28 @@ def abstract( assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape assert scale_aval is None or scale_aval.dtype == jnp.float32 + if stochastic_rounding: + assert ScalingMode( + scaling_mode + 
).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes" + # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64 + assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, ( + "sr_rng_state must be a uint32 array when stochastic_rounding is True but" + f" received {sr_rng_state_aval}" + ) + if is_outer: + assert ( + sr_rng_state_aval.shape[0] == num_of_devices() + and sr_rng_state_aval.shape[1] == 4 + ), ( + "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" + f" True and is_outer is True but received {sr_rng_state_aval.shape}" + ) + else: + assert sr_rng_state_aval.shape == (4,), ( + "Sharded sr_rng_state must be of shape (4,) per device when" + f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" + ) if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): rowwise_out_shape = out_shape @@ -98,14 +128,50 @@ def abstract( rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + updated_amax_aval = amax_aval + if use_rht: + assert ( + x_aval.dtype == jnp.bfloat16 + ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion." + + if flatten_axis < 0: + flatten_axis += len(x_aval.shape) + rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1) + cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1) + assert rows % 64 == 0 and cols % 128 == 0, ( + "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be" + f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape" + f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}." 
+ ) + + assert ( + rht_matrix_aval is not None + and rht_matrix_aval.dtype == jnp.bfloat16 + and rht_matrix_aval.shape == (16, 16) + ), "rht_matrix must be of shape (16, 16) and dtype bfloat16" + assert ( + post_rht_amax_aval is not None + and post_rht_amax_aval.dtype == jnp.float32 + and post_rht_amax_aval.size == 1 + ), "post_rht_amax must be of dtype float32" + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x_aval.shape, + is_padded=not is_outer, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -126,6 +192,7 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(scale_dtype), scaling_mode, QuantizeLayout( q_layout @@ -172,6 +239,9 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, *, out_dtype, scaling_mode, @@ -180,12 +250,14 @@ def lowering( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, amax_aval = ctx.avals_in + x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == amax_aval.dtype == jnp.float32 return ffi.ffi_lowering( @@ -196,10 +268,15 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) 
@staticmethod @@ -207,6 +284,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype, scaling_mode, q_layout, @@ -214,6 +294,8 @@ def impl( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p implementation @@ -232,6 +314,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -239,10 +324,14 @@ def impl( scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=False, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True + ) scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) @@ -271,6 +360,8 @@ def batcher( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ to describe batch rules for vmap @@ -278,8 +369,8 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax = batched_args - x_bdim, scale_bdim, amax_bdim = batch_dims + x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args + x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( @@ -287,12 +378,17 @@ def batcher( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=scale_dtype, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ), out_bdims, ) @@ -306,11 +402,20 @@ def infer_sharding_from_operands( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, 
result_infos, ): - del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. + del ( + out_dtype, + result_infos, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + ) # Unused. x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -320,7 +425,7 @@ def infer_sharding_from_operands( desc="BaseDBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -340,11 +445,19 @@ def infer_sharding_from_operands( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -376,11 +489,13 @@ def partition( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, result_infos, ): - del result_infos, is_outer + del result_infos, is_outer # Unused. 
x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -389,8 +504,9 @@ def partition( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -410,11 +526,19 @@ def partition( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -428,6 +552,7 @@ def partition( desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) + # TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, @@ -438,7 +563,7 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale, amax): + def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): ( local_x, local_colwise_x, @@ -450,6 +575,9 @@ def sharded_impl(x, scale, amax): x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -457,6 +585,8 @@ def sharded_impl(x, scale, amax): scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=True, + 
stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: @@ -489,35 +619,54 @@ def shardy_sharding_rule( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, is_outer, mesh, result_types + del ( + out_dtype, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + mesh, + result_types, + ) prefix = "DBiasQuantize_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[0].shape, unique_var=prefix + "x", flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, ) x_axes = scale_rules.input_spec - colwise_scale_inv = scale_rules.colwise_rule out = x_axes colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "colwise_scale_inv",) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + colwise_scale_inv = scale_rules.colwise_rule + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) + colwise_scale_inv = tuple( + multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis) + ) else: colwise_out = x_axes dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) + sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis") + + post_rht_amax = (prefix + "post_rht_amax",) + rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") return SdyShardingRule( - (x_axes, ("…1",), amax), + (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), **scale_rules.factor_sizes, ) @@ -534,141 +683,6 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -class AmaxScope(Enum): - """ - Amax Scope Enum - """ - - LOCAL = 1 - TPSP = 2 - FSDP = 3 - - def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): - """Reduce the amax based on its scope""" - gmesh = global_mesh_resource() - sequence_dim = 0 if transpose_batch_sequence else 1 - # Run AR across TPSP only when tensor-sequence is detected in the input spec - if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: - return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - # Run AR across FSDP - if self is AmaxScope.FSDP: - return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) - return amax - - -class AmaxCalculationPrimitive(BasePrimitive): - """ - Amax Calculation Primitive with custom_partitioning - """ - - name = "jax_local_amax" - multiple_results = False - impl_static_args = ( - 1, - 2, - ) # amax_scope, transpose_batch_sequence - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - x_aval, - *, - amax_scope, - transpose_batch_sequence, - ): - """ - amax calcuation abstract - """ - del amax_scope, transpose_batch_sequence - - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - - out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - return out_aval - - @staticmethod - def impl( - x, - amax_scope, - transpose_batch_sequence, - ): - """ - amax calcuation implementation - """ - del amax_scope, transpose_batch_sequence - amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) - return amax - - @staticmethod - def infer_sharding_from_operands( - amax_scope, - transpose_batch_sequence, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation infer_sharding_from_operands - """ - del (amax_scope, 
transpose_batch_sequence, arg_infos, result_infos) # Unused. - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="AmaxCalculationPrimitive.out_sharding", - ) - return amax_sharding - - @staticmethod - def partition( - amax_scope, - transpose_batch_sequence, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation partition - """ - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="AmaxCalculation.amax_sharding", - ) - - def sharded_impl(x): - amax = AmaxCalculationPrimitive.impl( - x, - amax_scope=amax_scope, - transpose_batch_sequence=transpose_batch_sequence, - ) - amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( - amax, x_spec, transpose_batch_sequence, mesh - ) - - return amax - - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - return mesh, sharded_impl, amax_sharding, arg_shardings - - @staticmethod - def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types): - """ - amax calcuation shardy_sharding_rule - """ - del amax_scope, transpose_batch_sequence, mesh, result_types - prefix = "AmaxCal" - input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) - output_spec = (f"{prefix}_amax",) - return SdyShardingRule((input_spec,), (output_spec,)) - - -register_primitive(AmaxCalculationPrimitive, outer_only=True) - - def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): @@ -740,7 +754,11 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): + is_unsupported = ( + quantizer.q_layout == QuantizeLayout.COLWISE + and quantizer.scaling_mode != 
ScalingMode.NVFP4_1D_SCALING + ) + if is_unsupported or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -767,15 +785,32 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias + use_rht = False + scale = jnp.empty((1,), jnp.float32) - amax = None + post_rht_amax = None + rht_matrix = jnp.empty((1, 1), jnp.bfloat16) + amax = x.amax + + if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): + use_rht = True + rht_matrix = get_rht_matrix() + + new_amax, post_rht_amax = calculate_post_rht_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + produce_regular_amax=amax is None, + flatten_axis=flatten_axis, + ) + if amax is None: + # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel + # So here we only calculate and update amax when it is not provided from a previous layer (amax is None) + amax = new_amax + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - # Globally reduce amax across all devices for current scaling so we have a single global scale. - # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this - # until the tensor is dequantized (e.g. in the GEMM). 
- amax = x.amax if amax is None: - amax = AmaxCalculationPrimitive.outer_primitive.bind( + amax = calculate_amax( x.data, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -783,8 +818,17 @@ def _quantize_dbias_impl( scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure to reset amax to zeros for DelayedScaling + amax = jnp.zeros((1,), jnp.float32) + elif quantizer.scaling_mode.is_nvfp4_scaling: + if amax is None: + amax = calculate_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) - # Make sure amax is init with zero + # Make sure amax is not None if amax is None: amax = jnp.zeros((1,), jnp.float32) @@ -796,9 +840,16 @@ def _quantize_dbias_impl( and is_1x_kernel_supported ) q_layout = quantizer.q_layout + if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE + sr_rng_state = None + if quantizer.scaling_mode.is_nvfp4_scaling: + # Only NVFP4 scaling modes support stochastic rounding + if quantizer.stochastic_rounding_rng_state is not None: + sr_rng_state = quantizer.stochastic_rounding_rng_state + ( rowwise_casted_output, colwise_casted_output, @@ -810,13 +861,18 @@ def _quantize_dbias_impl( x.data, scale, amax, + sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), + post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), + rht_matrix, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout.value, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), - is_dbias=is_dbias, + is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_outer=True, + stochastic_rounding=sr_rng_state is not None, + use_rht=use_rht, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise if quantizer.scaling_mode.is_tensor_scaling() and 
quantizer.is_2x2x(): @@ -830,14 +886,17 @@ def _quantize_dbias_impl( colwise_casted_output = jnp.transpose( rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) ) - quantizer.update(updated_amax) + if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias: + dbias = _jax_dbias(x, flatten_axis=flatten_axis) out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, colwise_data=colwise_casted_output, colwise_scale_inv=colwise_scale_inv, + amax=updated_amax, + colwise_amax=post_rht_amax, scaling_mode=quantizer.scaling_mode, dq_dtype=dq_dtype, q_layout=quantizer.q_layout, @@ -955,6 +1014,11 @@ def abstract( # TODO(Phuong): can scale_aval be None? assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_grouped_scale_shape_2x( diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3ce6dee731..87c6fa91cd 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -85,7 +85,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, QuantizeLayout q_layout); @@ -138,6 +138,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +// Amax +XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); 
+XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp new file mode 100644 index 0000000000..46f167fcaf --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -0,0 +1,100 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#include + +#include + +#include "../extensions.h" +#include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" +#include "transformer_engine/recipe.h" +#include "transformer_engine/transformer_engine.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf, + Result_Type post_rht_amax_buf, + int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax, + int64_t flatten_axis) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, + "Input must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16, + "Input must be of type bfloat16 for RHT Amax calculation"); + + NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast(input_buf.dimensions().size()), + "Flatten axis is out of bounds"); + TensorWrapper input_tensor(input_buf.untyped_data(), + std::vector{product(input_buf.dimensions(), 0, flatten_axis), + product(input_buf.dimensions(), flatten_axis, + input_buf.dimensions().size())}, + convert_ffi_datatype_to_te_dtype(input_buf.element_type())); + + float *amax_out = nullptr; + if (produce_regular_amax) { + amax_out = reinterpret_cast(amax_buf->untyped_data()); + 
NVTE_CHECK(amax_out != nullptr, "Amax output must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32, + "Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1, + "Amax output must be a single float for RHT Amax calculation"); + } + + float *post_rht_amax_out = reinterpret_cast(post_rht_amax_buf->untyped_data()); + NVTE_CHECK(post_rht_amax_out != nullptr, + "Post-RHT Amax output must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32, + "Post-RHT Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 && post_rht_amax_buf->dimensions()[0] == 1, + "Post-RHT Amax output must be a single float for RHT Amax calculation"); + + TensorWrapper out_tensor{}; + out_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + out_tensor.set_columnwise_amax(post_rht_amax_out, DType::kFloat32, std::vector{1}); + + // Zero'ing of amaxes is handled by TE common inside nvte_hadamard_transform_amax + nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(), + 0, // Regular amax for rowwise does not apply RHT so mask is 0 + rht_matrix_random_sign_mask_t, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationHandler, RHTAmaxCalculationFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + .Attr("flatten_axis"), // flatten_axis + FFI_CudaGraph_Traits); + +Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type amax_buf, Result_Type post_rht_amax_buf, + int64_t 
rht_matrix_random_sign_mask_t, + bool produce_regular_amax, int64_t flatten_axis) { + return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf, + post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax, + flatten_axis); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + .Attr("flatten_axis")); // flatten_axis + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index e77c38e990..a0425efda6 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -41,6 +41,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E8M0FNU: return DType::kFloat8E8M0; break; + case xla::ffi::DataType::F4E2M1FN: + return DType::kFloat4E2M1; + break; default: auto type_num = static_cast(type); NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 82f062a15b..0fc2e83898 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -102,6 +102,8 @@ inline static size_t te_dtype_bytes(const DType& type) { return 1; case DType::kFloat8E8M0: return 1; + case DType::kFloat4E2M1: + return 1; default: NVTE_ERROR("Unsupported DType: ", static_cast(type)); } diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 993ec1377d..8a3658a0ba 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ 
b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -51,7 +51,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set scaling factor for quantized tensors if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { - NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1, + "Quantized GEMM requires 4-bit or 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); std::vector scale_shape = {1}; @@ -74,7 +75,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, @@ -119,6 +121,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out @@ -136,11 +140,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, - JAXX_Collective_Op collective_op) { + Buffer_Type alpha, Buffer_Type beta, 
Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || @@ -192,10 +196,31 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); - - // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + float one = 1.; + float zero = 0.; + // alpha, beta + float *alpha_ptr = &one, *beta_ptr = &zero; + if (is_nvfp4_scaling(scaling_mode)) { + NVTE_CHECK(alpha.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32); + alpha_ptr = reinterpret_cast(alpha.untyped_data()); + NVTE_CHECK(beta.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32); + beta_ptr = reinterpret_cast(beta.untyped_data()); + } + + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sm); + if (fuse_bias) config.set_bias_tensor(bias_.data()); + if (fuse_gelu) { + config.set_with_gelu_epilogue(true); + config.set_epilogue_aux_tensor(pre_gelu_.data()); + } + if (collective_op == JAXX_Collective_Op::NONE) { auto out_ = TensorWrapper(output->untyped_data(), out_shape, 
out_dtype); NVTE_CHECK(out_.numel() == output->element_count(), @@ -205,9 +230,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, ", out_shape[1]=", out_shape[1]); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr, + rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, + out_.data() /*D*/, workspace_.data(), config, stream); } else { std::vector buffer_shape{0, 0}; DType buffer_dtype = out_dtype; @@ -268,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out @@ -599,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // point to swizzled scale_inv data (store on workspace, only used for GEMM). 
// Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = - get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); auto rhs_sinv_shape_i = - get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index ee81b5ad72..176115ade9 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -26,11 +26,21 @@ std::vector Shape::to_vector() const { return shape; } -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) { - auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x; - auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y; - auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x; - auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y; +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise) { + auto block_size = BLOCK_SIZE(1, 1); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + block_size = MXFP8_BLOCK_SIZE; + } else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + block_size = NVFP4_BLOCK_SIZE; + } else { + NVTE_ERROR("Unsupported scaling_mode = ", static_cast(scaling_mode)); + } + auto block_x = is_colwise ? block_size.y : block_size.x; + auto block_y = is_colwise ? block_size.x : block_size.y; + auto alignment_x = is_colwise ? 
BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x; + auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y; NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M); NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8fb713d7d..07e9aec7e9 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t { DELAYED_TENSOR_SCALING = 1, MXFP8_1D_SCALING = 2, CURRENT_TENSOR_SCALING = 3, + NVFP4_1D_SCALING = 4, + NVFP4_2D_SCALING = 5, }; inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { @@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); } +inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: @@ -70,22 +77,32 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING: return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; break; + case JAXX_Scaling_Mode::NVFP4_1D_SCALING: + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; + case JAXX_Scaling_Mode::NVFP4_2D_SCALING: + // TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; default: NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); break; } } -constexpr struct BlockSize { +struct BLOCK_SIZE { size_t x; size_t y; -} MXFP8_BLOCK_SIZE{1, 32}; -constexpr struct Alignment { - size_t x; 
- size_t y; -} MXFP8_ALIGNMENT{128, 4}; + constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {} +}; + +constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32}; +constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16}; + +constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4}; -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise); template void hash_combine(int64_t &seed, const T &v, Rest... rest) { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index f6b1acd439..d740df0e2a 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -76,6 +76,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + // Amax + dict["te_rht_amax_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + return dict; } @@ -106,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("kFloat16", DType::kFloat16) .value("kBFloat16", DType::kBFloat16) .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); + .value("kFloat8E5M2", DType::kFloat8E5M2) + .value("kFloat8E8M0", DType::kFloat8E8M0) + .value("kFloat4E2M1", DType::kFloat4E2M1); pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) @@ -165,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) + .value("NVFP4_1D_SCALING", 
JAXX_Scaling_Mode::NVFP4_1D_SCALING) + .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 05260741b6..a45a698822 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -5,8 +5,11 @@ ************************************************************************/ #include +#include + #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" #include "xla/ffi/api/c_api.h" @@ -15,7 +18,7 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -30,16 +33,22 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ // this function. We pass a dummy pointer as a workaround. 
int temp = 0; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto scale_shape = std::vector{1}; // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false); + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } @@ -49,13 +58,16 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = + get_block_scale_shape(scaling_mode, hidden_size, batch_size, false); //Transpose + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } - if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); 
output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -72,17 +84,20 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Buffer_Type amax_buf, Result_Type output_buf, - Result_Type output_trans_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, - bool is_dbias, int64_t flatten_axis) { + Buffer_Type amax_buf, Buffer_Type sr_rng_state, + Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, + bool stochastic_rounding, bool use_rht) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); - NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization."); + NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype), + "Output datatype must be FP8 or FP4 for quantization."); auto *input = input_buf.untyped_data(); @@ -112,41 +127,106 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == 
JAXX_Scaling_Mode::NVFP4_2D_SCALING; + + NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); + NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - if (is_tensor_scaling) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf.untyped_data()); - float *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax == updated_amax && amax != nullptr, - "amax must be provided for delayed tensor scaling"); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{1}); - } else { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), - product(scale_inv_buf->dimensions(), flatten_axis, - scale_inv_buf->dimensions().size())}); - } + if (is_tensor_scaling) { + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + 
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); } } + if (is_nvfp4) { + float *amax = reinterpret_cast(amax_buf.untyped_data()); + NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4"); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } + + QuantizationConfigWrapper quant_config{}; + if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + quant_config.set_nvfp4_2d_quantization(true); + } + + // Stochastic rounding + quant_config.set_stochastic_rounding(stochastic_rounding); + TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector{2}, + DType::kInt64); + if (stochastic_rounding) { + NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t), + "rng_state must be of type int64[2]"); + NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR"); + quant_config.set_rng_state(sr_rng_state_tensor.data()); + } + if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - ? 
output_trans_shape - : output_shape; + if (is_nvfp4 && use_rht) { + if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + // Do regular rowwise quantization without RHT + nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); + } + + TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode)); + + // nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers on TensorWrapper + out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape); + auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis; + out_transpose.set_rowwise_scale_inv( + colwise_scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), + std::vector{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis), + product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis, + colwise_scale_inv_buf->dimensions().size())}); + + float *post_rht_amax = reinterpret_cast(post_rht_amax_buf.untyped_data()); + NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4"); + out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector{1}); + + bool const eligible_for_rht_cast_fusion = + input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0; + NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met"); + + NVTE_CHECK( + convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16, + "RHT matrix must be bf16"); + NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 && + rht_matrix_buf.dimensions()[1] == 16, + "RHT matrix must be 16x16"); + TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector{16, 16}, + DType::kBFloat16); + + nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(), + rht_matrix_tensor.data(), quant_config, + stream); + 
+ return ffi_with_cuda_error_check(); + } + + bool const is_colwise_transposed = + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; + auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf; @@ -156,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); } else { + auto colwise_flatten_axis = flatten_axis; + if (is_colwise_transposed) { + // convert flatten_axis from N layout to T layout + colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis; + } output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{ - product(tmp_buf->dimensions(), 0, flatten_axis), - product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + product(tmp_buf->dimensions(), 0, colwise_flatten_axis), + product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())}); } } - if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); - } - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); if (is_dbias) { + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING, + "DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot " + "take quant_config as input."); nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), workspace_tensor.data(), stream); } else { - nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + 
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); } return ffi_with_cuda_error_check(); } @@ -186,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Arg() // input .Arg() // scale .Arg() // amax + .Arg() // sr_rng_state + .Arg() // colwise amax + .Arg() // rht matrix .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -196,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") - .Attr("flatten_axis"), + .Attr("flatten_axis") + .Attr("stochastic_rounding") + .Attr("use_rht"), FFI_CudaGraph_Traits); Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -346,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty sinv_size = 1; } else { const bool is_colwise = false; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_rowwise_scale_inv(static_cast(sinv_ptr), sinv_dtype, sinv_shape_i); sinv_size = product(sinv_shape_i); } @@ -365,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty colwise_sinv_size = 1; } else { const bool is_colwise = true; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_columnwise_scale_inv(static_cast(colwise_sinv_ptr), sinv_dtype, sinv_shape_i); colwise_sinv_size = product(sinv_shape_i); diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 28525a22a9..44c73a5b1e 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from . 
import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 76865f7c12..c54ecb236f 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -15,7 +15,6 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from transformer_engine.common import recipe from ..dense import dense @@ -35,10 +34,9 @@ from ..quantize import ( QuantizerFactory, get_quantize_config, - QuantizeMeta, QuantizeMetaSet, - ScalingMode, TensorSource, + get_quantize_config_with_recipe, ) PRNGKey = Any @@ -353,40 +351,32 @@ def generate_quantizer_set( Generate a set of FP8 meta for a GEMM. """ - def generate_quantize_meta(quantizer_name: str): - collection_name = ( - variable_collection - if variable_collection is not None - else get_quantize_config().COLLECTION_NAME - ) - scale = self.variable( - collection_name, - f"{quantizer_name}{postfix}_scale", - jnp.ones, - (1,), - jnp.float32, - ).value - amax_history = self.variable( - collection_name, - f"{quantizer_name}{postfix}_amax_history", - jnp.zeros, - (get_quantize_config().AMAX_HISTORY_LEN,), - jnp.float32, - ).value - return QuantizeMeta(scale=scale, amax_history=amax_history) - - if get_quantize_config().get_scaling_mode( - TensorSource.X - ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): - x_meta = generate_quantize_meta("x") - kernel_meta = generate_quantize_meta("kernel") - grad_meta = generate_quantize_meta("grad") - quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) - kwargs = {"quantize_meta_set": quantize_meta_set} + collection_name = ( + variable_collection + if variable_collection is not None + else get_quantize_config().COLLECTION_NAME + ) + + if fp8_recipe is None: + 
quantize_config = get_quantize_config() else: - kwargs = {} + quantize_config = get_quantize_config_with_recipe(fp8_recipe) - quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) + x_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.X, "x" + ) + kernel_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.KERNEL, "kernel" + ) + grad_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.DGRAD, "grad" + ) + + quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) + + quantizer_set = QuantizerFactory.create_set( + fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set + ) return quantizer_set diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 136f43df41..705c742326 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .quantize import ( QuantizerSet, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index c43430cf36..100848fdd5 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,7 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . 
import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 11f692917f..9616965c75 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -14,5 +14,6 @@ from .dequantizer import * from .scaling_modes import * from .metadata import * +from .hadamard import * from .helper import * from .device_utils import * diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 9d46c3c300..b4da6f3bed 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -15,6 +15,8 @@ import jax.numpy as jnp from .scaling_modes import ScalingMode +from .hadamard import apply_rht, should_use_rht + __all__ = ["ScalingModeToDequantizerMap"] @@ -119,7 +121,7 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte 0 < flatten_axis < len(data_shape) ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" scale_shape = scaling_mode.get_scale_shape( - data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) data = data.reshape( @@ -161,10 +163,99 @@ def dequantize(scaled_tensor): ) +class NVFP4Dequantizer(Dequantizer): + """NVFP4 Dequantizer Class. + + This class provides static methods for dequantizing tensors that have been + quantized using NVFP4 scaling modes. + """ + + @staticmethod + def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis): + """Dequantize a tensor using block scaling. 
+ + Args: + data: The quantized tensor data + scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D + + Returns: + The dequantized tensor + """ + + DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32) + tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) + + data = data.astype(jnp.float32) + scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv + data_layout = "T" if is_colwise else "N" + + data_shape = data.shape + flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + scale_shape = scaling_mode.get_scale_shape( + data_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + # expect the flatten_axis wrt the N layout + flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + + data = data.reshape( + *data_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *data_shape[flatten_axis:-1], + scale_shape[-1], + int(data_shape[-1] / scale_shape[-1]), + ) + + scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1)) + out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape) + + # Apply inverse of RHT if needed + use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise) + if use_rht: + out = apply_rht(out, inverse=True) + + return out + + @staticmethod + def dequantize(scaled_tensor): + """Dequantize a tensor using block scaling. 
+ + Args: + scaled_tensor: The quantized tensor to dequantize + + Returns: + The dequantized tensor + """ + return NVFP4Dequantizer._dequantize_func( + scaled_tensor.data, + scaled_tensor.scale_inv, + scaled_tensor.amax, + scaled_tensor.dq_dtype, + scaled_tensor.scaling_mode, + scaled_tensor.is_colwise, + scaled_tensor.flatten_axis, + ) + + ScalingModeToDequantizerMap = { ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer, + ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer, ScalingMode.NO_SCALING: NoopDequantizer, } @@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor): ) padded_scale_shape_i = scaling_mode.get_scale_shape( data_shape_i, - grouped_scaled_tensor.is_colwise, + is_colwise=grouped_scaled_tensor.is_colwise, is_padded=True, flatten_axis=flatten_axis, ) unpadded_scale_shape_i = scaling_mode.get_scale_shape( data_shape_i, - grouped_scaled_tensor.is_colwise, + is_colwise=grouped_scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis, ) diff --git a/transformer_engine/jax/quantize/hadamard.py b/transformer_engine/jax/quantize/hadamard.py new file mode 100644 index 0000000000..c0b74ef75e --- /dev/null +++ b/transformer_engine/jax/quantize/hadamard.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Randomized Hadamard Transform (RHT) utilities for JAX.""" +import jax.numpy as jnp + +from .scaling_modes import ScalingMode + + +def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool: + """Determine if RHT (Randomized Hadamard Transform) should be used. + + Args: + scaling_mode: The scaling mode of the tensor. + is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided. 
+ q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided. + + Returns: + bool: True if RHT should be used, False otherwise. + """ + # Delayed import to avoid circular dependencies + from .quantizer import QuantizeLayout + + assert (is_colwise is None) != ( + q_layout is None + ), "Exactly one of is_colwise or q_layout must be provided." + + if q_layout is not None: + is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE} + + return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise + + +def get_wgrad_sign_vector() -> list[int]: + """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization.""" + return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1] + + +def get_sign_from_vector(vector: list[int]) -> int: + """Convert a sign vector to a bitmask integer.""" + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask + + +def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray: + """Apply the Randomized Hadamard Transform (RHT) to the input tensor.""" + h = get_rht_matrix() + block_size = 16 + if inverse: + h = jnp.linalg.inv(h.astype(jnp.float32)).astype(jnp.bfloat16) + # TODO(jberchtold): These reshapes will break partitioning, fixme + return (x.reshape(-1, block_size) @ h).reshape(x.shape) + + +def get_rht_matrix() -> jnp.ndarray: + """Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization. + + Returns: + A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask. 
+ """ + import scipy + + block_size = 16 + h = jnp.array(scipy.linalg.hadamard(block_size)) + + # Apply the random sign mask + s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32) + h = jnp.diag(s) @ h + + return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 67f0a68c6a..70611cbea9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -11,9 +11,12 @@ from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence, Type -from functools import reduce +from typing import Optional, Tuple, Dict, Union, Sequence, Type, List +from functools import reduce, lru_cache import operator +from importlib.metadata import version as get_pkg_version +import warnings +from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -21,18 +24,27 @@ from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine.common import recipe -from transformer_engine.jax.sharding import global_shard_guard, MeshResource - +from transformer_engine.jax.sharding import ( + global_shard_guard, + MeshResource, + num_of_devices, + get_all_mesh_axes, + with_sharding_constraint, +) + +from .metadata import QuantizeMeta from .scaling_modes import ScalingMode -from .. 
import cpp_extensions as tex from .device_utils import get_device_compute_capability __all__ = [ "get_quantize_config", + "get_quantize_config_with_recipe", "fp8_autocast", "is_fp8_available", + "is_scaling_mode_supported", + "get_supported_scaling_modes", + "get_supported_quantization_recipes", "update_collections", - "get_delayed_scaling", "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", @@ -41,11 +53,23 @@ _is_fp8_available = None _reason_for_no_fp8 = "" +_is_scaling_mode_supported = None +_reason_for_no_scaling_mode = "" Collection = Union[Dict, FrozenDict] NVTE_FP8_COLLECTION_NAME = "fp8_metas" +@lru_cache(maxsize=None) +def _jax_version_meet_requirement(version: str): + """ + Helper function checking if required JAX version is available + """ + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -55,8 +79,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ - if gpu_arch >= 90: # hopper and above - return True, "" if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: @@ -75,20 +97,31 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ - if gpu_arch >= 100: # blackwell and above - return True, "" if gpu_arch < 99: # pre-blackwell return False, "Device compute capability 9.9 or higher required for MXFP8 execution." if get_cublasLt_version() < 120800: return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution." 
- if get_cuda_version() < 12010: + if get_cuda_version() < 12080: return False, "Cuda version 12.8 or higher required for MXFP8 execution." - if not tex.jax_version_meet_requirement("0.5.3"): + if not _jax_version_meet_requirement("0.5.3"): return False, "Jax version 0.5.3 or higher required for MXFP8 execution." return True, "" -def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: +def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: + """Check if FP4 is supported for the given GPU architecture.""" + if gpu_arch < 100: # pre-blackwell + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + if get_cublasLt_version() < 120800: + return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution." + if get_cuda_version() < 12080: + return False, "Cuda version 12.8 or higher required for NVFP4 execution." + if not _jax_version_meet_requirement("0.5.3"): + return False, "Jax version 0.5.3 or higher required for NVFP4 execution." + return True, "" + + +def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]: """Check if FP8 is supported for the given scaling mode and GPU. 
Args: @@ -101,9 +134,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: gpu_arch = get_device_compute_capability(gpu_id) if scaling_mode.is_tensor_scaling(): return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if scaling_mode.is_mxfp8_scaling: return _check_block_scaling_fp8_support(gpu_arch) - return (False, "Unsupported scaling_mode!") + if scaling_mode.is_nvfp4_scaling: + return _check_fp4_support(gpu_arch) + return (True, "") # NO_SCALING is always supported + + +def is_scaling_mode_supported( + scaling_mode=ScalingMode.NO_SCALING, + gpu_id=None, +) -> Tuple[bool, str]: + """Check if the given scaling mode is available for the given GPU.""" + if gpu_id is not None: + return _check_scaling_support(scaling_mode, gpu_id) + + global _is_scaling_mode_supported, _reason_for_no_scaling_mode + if _is_scaling_mode_supported is None: + _is_scaling_mode_supported = {} + _reason_for_no_scaling_mode = {} + if scaling_mode not in _is_scaling_mode_supported: + _is_scaling_mode_supported[scaling_mode] = True + _reason_for_no_scaling_mode[scaling_mode] = "" + for local_gpu_id in range(len(jax.local_devices())): + ret, msg = _check_scaling_support(scaling_mode, local_gpu_id) + if ret is False: + _is_scaling_mode_supported[scaling_mode] = ret + _reason_for_no_scaling_mode[scaling_mode] = msg + return ret, msg + return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] def is_fp8_available( @@ -119,26 +178,36 @@ def is_fp8_available( Returns: A tuple of (bool, str) indicating availability and any error message """ - if gpu_id is not None: - return _check_fp8_support(scaling_mode, gpu_id) - - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available = {} - _reason_for_no_fp8 = {} - - if scaling_mode not in _is_fp8_available: - _is_fp8_available[scaling_mode] = True - _reason_for_no_fp8[scaling_mode] = "" - # JAX doesn't provide the local 
GPU id. - for local_gpu_id in range(len(jax.local_devices())): - ret, msg = _check_fp8_support(scaling_mode, local_gpu_id) - if ret is False: - _is_fp8_available[scaling_mode] = ret - _reason_for_no_fp8[scaling_mode] = msg - return ret, msg - - return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode] + warnings.warn( + "is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning + ) + return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id) + + +# TODO(Phuong): make the infrastruture to support NO_SCALING +def get_supported_scaling_modes() -> List[ScalingMode]: + """Get all supported quantization scaling modes.""" + return [ + scaling_mode + for scaling_mode in ScalingMode + if is_scaling_mode_supported(scaling_mode=scaling_mode)[0] + and scaling_mode != ScalingMode.NO_SCALING + ] + + +def get_supported_quantization_recipes() -> List[recipe.Recipe]: + """Get all supported quantization recipes.""" + # We don't support all the recipes TE/Common supports yet + # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()] + all_recipes = [ + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.MXFP8BlockScaling(), + recipe.NVFP4BlockScaling(), + ] + return [ + recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0] + ] def _format2dtypes(format_: recipe.Format): @@ -156,6 +225,8 @@ def _format2dtypes(format_: recipe.Format): return jnp.float8_e5m2, jnp.float8_e5m2 if format_ == recipe.Format.HYBRID: return jnp.float8_e4m3fn, jnp.float8_e5m2 + if format_ == recipe.Format.E2M1: + return jnp.float4_e2m1fn, jnp.float4_e2m1fn return jnp.bfloat16, jnp.bfloat16 @@ -193,7 +264,6 @@ class BaseQuantizeConfig(ABC): INITIALIZED: Whether the config has been initialized MARGIN: Margin value for quantization COLLECTION_NAME: Name of the collection for quantization metadata - FP8_FORMAT: FP8 format to use FWD_DTYPE: Forward pass data type 
BWD_DTYPE: Backward pass data type FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass @@ -207,28 +277,26 @@ class BaseQuantizeConfig(ABC): INITIALIZED = False MARGIN: float = 0.0 COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FP8_FORMAT: recipe.Format = recipe.Format.HYBRID - FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] - BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] + FWD_DTYPE: DType = None + BWD_DTYPE: DType = None FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False # DelayedScaling + # TODO(Phuong): move these two into DelayedScalingQuantizeConfig AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + """Initialize the quantization configuration from a given recipe. Args: fp8_recipe: The FP8 recipe to use for initialization """ self.INITIALIZED = True - self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - self.FP8_FORMAT = fp8_recipe.fp8_format - self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format) def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. @@ -249,6 +317,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: The scaling mode for the specified usage type. """ + @abstractmethod + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. 
+ + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + def is_supported(self) -> tuple[bool, str]: """Check if this QuantizeConfig class is supported on the available devices. @@ -261,7 +350,7 @@ def is_supported(self) -> tuple[bool, str]: kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: - is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode) if not is_supported: return is_supported, reason return True, None @@ -281,6 +370,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.NO_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. 
+ """ + return QuantizeMeta() + class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. @@ -299,6 +409,7 @@ def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: AssertionError: If recipe parameters are not supported """ super().initialize_from_recipe(fp8_recipe) + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 assert fp8_recipe.amax_compute_algo in [ "max", @@ -323,6 +434,41 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.DELAYED_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + scale = module.variable( + collection_name, + f"{quantizer_name}{postfix}_scale", + jnp.ones, + (1,), + jnp.float32, + ).value + amax_history = module.variable( + collection_name, + f"{quantizer_name}{postfix}_amax_history", + jnp.zeros, + (self.AMAX_HISTORY_LEN,), + jnp.float32, + ).value + return QuantizeMeta(scale=scale, amax_history=amax_history) + class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. 
@@ -344,6 +490,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.CURRENT_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. @@ -365,6 +532,91 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.MXFP8_1D_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + + +class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): + """Configuration class for NVFP4 scaling recipe. 
+ + This class provides specific initialization and finalization for NVFP4 scaling quantization mode. + """ + + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize block scaling FP8 configuration. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + self.INITIALIZED = True + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) + self.AMAX_HISTORY_LEN = 0 + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + if tensor_source == TensorSource.KERNEL: + return ScalingMode.NVFP4_2D_SCALING + # for x and grad + return ScalingMode.NVFP4_1D_SCALING + + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + if tensor_source != TensorSource.DGRAD: + # Only DGRAD uses stochastic rounding + return QuantizeMeta() + + # TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it. 
+ sr_jax_rng = module.make_rng("sr_rng") + # Get a unique key for this quantizer + sr_jax_rng = jax.jit(jax.random.fold_in)( + sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max + ) + + # Generate 4 random uint32 values from the JAX PRNG key + sr_jax_rng_state = jax.random.randint( + sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 + ).view(jnp.uint32) + sr_jax_rng_state = with_sharding_constraint( + sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None) + ) + + return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state) + _QUANTIZE_CONFIG = NoOpQuantizeConfig() @@ -377,7 +629,7 @@ def get_quantize_config(): def get_quantize_config_class( fp8_recipe: recipe.Recipe, ) -> Type[BaseQuantizeConfig]: - """Get the quantization configuration based on the FP8 recipe. + """Get the quantization configuration class based on the FP8 recipe. Args: fp8_recipe: The FP8 recipe to use for initialization @@ -390,9 +642,18 @@ def get_quantize_config_class( return BlockScalingQuantizeConfig if isinstance(fp8_recipe, recipe.Float8CurrentScaling): return CurrentScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.NVFP4BlockScaling): + return NVFP4ScalingQuantizeConfig raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") +def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe): + """Get the quantization configuration object based on the FP8 recipe.""" + config = get_quantize_config_class(fp8_recipe)() + config.initialize_from_recipe(fp8_recipe) + return config + + @contextmanager def fp8_autocast( enabled: bool = False, @@ -457,31 +718,6 @@ def fp8_autocast( _QUANTIZE_CONFIG = old_quantize_config -def get_delayed_scaling(): - r""" - Obtain an instance of DelayedScaling which is set via fp8_autocast. - - .. note:: - We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len` - , and :attr:`amax_compute_algo` via fp8_autocast. 
Other parameters in - recipe.DelayedScaling would be returned as the default values. - - Returns - ------- - delay_scaling : DelayedScaling - an instance of DelayedScaling which is set via fp8_autocast. - """ - amax_compute_algo = ( - "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" - ) - return recipe.DelayedScaling( - margin=int(get_quantize_config().MARGIN), - fp8_format=get_quantize_config().FP8_FORMAT, - amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, - amax_compute_algo=amax_compute_algo, - ) - - def update_collections(new: Collection, original: Collection) -> Collection: r"""Update collections with new values while preserving original structure. diff --git a/transformer_engine/jax/quantize/metadata.py b/transformer_engine/jax/quantize/metadata.py index 6374502165..11a349ed7d 100644 --- a/transformer_engine/jax/quantize/metadata.py +++ b/transformer_engine/jax/quantize/metadata.py @@ -9,23 +9,29 @@ scale factors and amax history for different tensor types. """ from dataclasses import dataclass -import jax.numpy as jnp __all__ = ["QuantizeMeta", "QuantizeMetaSet"] -@dataclass class QuantizeMeta: """Metadata for quantization parameters. 
- Attributes: + For Delayed Scaling recipe: scale: The scaling factor for quantization amax_history: History of maximum absolute values + + For NVFP4 recipe with Stochastic Rounding: + sr_rng_state: The state of the stochastic rounding RNG + """ - scale: jnp.ndarray - amax_history: jnp.ndarray + def __init__(self, **kwargs): + self._kwargs = kwargs + + def get_kwargs_dictionary(self): + """Get the metadata as a dictionary.""" + return self._kwargs @dataclass diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 306603bbe1..7198014f2e 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,6 +19,7 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode +from .hadamard import apply_rht, should_use_rht from .tensor import ( ScaledTensor, ScaledTensor1x, @@ -28,7 +29,7 @@ ) from .helper import ( get_quantize_config, - get_quantize_config_class, + get_quantize_config_with_recipe, AmaxComputeAlgo, TensorSource, ) @@ -66,6 +67,7 @@ def compute_scale_from_amax( sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) + assert sf.shape == (1,) return sf @@ -155,7 +157,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> """ def quantize( - self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs + self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs ) -> ScaledTensor: """Quantize a tensor using the internal _quantize_func(). 
@@ -170,6 +172,18 @@ def quantize( A ScaledTensor1x or ScaledTensor2x containing the quantized data """ del kwargs + + is_rowwise = ( + is_rowwise + if is_rowwise is not None + else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) + ) + is_colwise = ( + is_colwise + if is_colwise is not None + else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) + ) + if (is_rowwise and is_colwise) or self.is_2x2x(): rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( @@ -380,6 +394,7 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + # Note, this updating of amax here will only be called once because the "quantize" method impl inherited from CurrentScaleQuantizer only calls _quantize_func once then transposes the result for colwise quantization. So we don't have to worry about update being called twice for 2x2x quantization. self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, @@ -494,7 +509,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> dq_dtype = dq_dtype if dq_dtype is not None else x.dtype x_shape = x.shape scale_shape = self.scaling_mode.get_scale_shape( - x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) scale_dtype = self.scaling_mode.get_scale_dtype() x = x.reshape( @@ -563,6 +578,221 @@ def _e8m0_to_dtype(self, x, dtype): return new_x.astype(dtype) +@register_pytree_node_class +@dataclass +class NVFP4Quantizer(Quantizer): + """Quantizer implementation using current scaling. 
+ + This quantizer uses current scaling mode with float32 scales + + Attributes: + scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING + q_layout: Quantization axis + data_layout: Data layout string (default: "NT") + stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. + """ + + scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + data_layout: str = "NT" + stochastic_rounding_rng_state: Optional[jnp.ndarray] = None + + def __post_init__(self): + assert ( + self.q_dtype == jnp.float4_e2m1fn + ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" + assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" + + def _apply_stochastic_rounding(self, x): + assert ( + self.stochastic_rounding_rng_state is not None + ), "Stochastic rounding RNG state is not initialized" + assert self.stochastic_rounding_rng_state.shape == ( + 4, + ), "Stochastic rounding RNG state must be of shape (4,)" + assert ( + self.stochastic_rounding_rng_state.dtype == jnp.uint32 + ), "Stochastic rounding RNG state must be of dtype uint32" + + # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s + key_bits = jnp.array( + [ + self.stochastic_rounding_rng_state[0], + self.stochastic_rounding_rng_state[1], + ], + dtype=jnp.uint32, + ) + key = jax.random.wrap_key_data(key_bits) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2]) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3]) + + abs_x = jnp.abs(x) + sign_x = jnp.sign(x) + + floor = ( + (abs_x >= 0.5) * 0.5 + + (abs_x >= 1) * 0.5 + + (abs_x >= 2) + + (abs_x >= 3) + + (abs_x >= 4) + + (abs_x >= 6) * 2 + ) + ceil = ( + 0.5 + + (abs_x > 0.5) * 0.5 + + (abs_x > 1) * 1 + + (abs_x > 2) + + (abs_x > 3) + + (abs_x > 4) * 2 + ) + 
frac = (abs_x - floor) / (ceil - floor) + + rand = jax.random.uniform(key, abs_x.shape) + return sign_x * jnp.where(frac >= rand, ceil, floor) + + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: + """Quantize function helper for block scaling FP8. + + Args: + x: Input tensor to quantize + is_colwise: Whether to use column-wise quantization + dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor + + Returns: + A ScaledTensor1x containing the quantized data + """ + # TODO(Phuong): use quantize_func from JAX + if flatten_axis < 0: + flatten_axis = x.ndim + flatten_axis + assert ( + 0 <= flatten_axis < x.ndim + ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}" + + should_apply_rht = self.scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise + + global_amax = None + if isinstance(x, NoScaleTensor): + global_amax = ( + x.amax if not should_apply_rht else None + ) # RHT changes the amax so don't use precalculated amax for colwise 1D nvfp4 quantization with RHT + x = x.data + + # Transpose if required + rowwise_flatten_axis = flatten_axis + data_layout = self.data_layout[0] + if is_colwise: + x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis))) + data_layout = self.data_layout[1] + # convert flatten_axis from N layout to T layout + flatten_axis = x.ndim - flatten_axis + x_shape = x.shape + + if should_use_rht(self.scaling_mode, is_colwise=is_colwise): + # We only apply RHT for 1D colwise nvfp4 + x = apply_rht(x) + + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + ) + scale_dtype = self.scaling_mode.get_scale_dtype() + x = x.reshape( + *x_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 
1]), + *x_shape[flatten_axis:-1], + scale_shape[-1], + int(x_shape[-1] / scale_shape[-1]), + ) + + # Dtype max constants + DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32) + + # Level 1: Current Tensor Scaling + global_amax = ( + global_amax + if global_amax is not None + else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32) + ) + tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax + tensor_scale = jnp.minimum( + tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + tensor_scale = jnp.where( + tensor_scale == jnp.array(0.0, dtype=jnp.float32), + jnp.array(1.0, dtype=jnp.float32), + tensor_scale, + ) + tensor_scale_inv = 1.0 / tensor_scale + + # Level 2: Block Scaling + block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True).astype( + jnp.float32 + ) + block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX) + block_scale_inv = block_scale_inv * tensor_scale + block_scale_inv = jnp.minimum( + block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX) + # We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate. + block_scale_inv = block_scale_inv.astype(scale_dtype) + # Note, with JIT jax removes this intermediate cast leading to slightly incorrect results during DQ and worse convergence to the original tensor during many samples of Q+SR->DQ. So we use reduce_precision to simulate the cast to scale_dtype. 
+ assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype" + block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3) + block_scale = jnp.minimum( + jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv), + jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32), + ) + + # Apply scaling + scaled_x = x.astype(jnp.float32) * block_scale + if self.stochastic_rounding_rng_state is not None: + scaled_x = self._apply_stochastic_rounding(scaled_x) + clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX) + + # Cast to the right dtype + quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype) + block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype) + + # In the 2D scaling mode, the scale shape is 2D but it needs to be broadcasted to 1D for GEMM. + # TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the + # quantizer.quantize() API + broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + + # Broadcast and tile x to match the target shape + def repeat_to_shape(x, target_shape): + x_shape = x.shape + reps = [int(t // s) for s, t in zip(x_shape, target_shape)] + return jnp.tile(x, reps) + + block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape) + + return ScaledTensorFactory.create_1x( + data=quantized_data, + data_layout=data_layout, + is_colwise=is_colwise, + scale_inv=block_scale_inv, + amax=global_amax, + scaling_mode=self.scaling_mode, + dq_dtype=dq_dtype, + flatten_axis=rowwise_flatten_axis, + ) + + @register_pytree_node_class @dataclass class QuantizerSet: @@ -801,6 +1031,8 @@ class QuantizerFactory: ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: 
BlockScaleQuantizer, + ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer, + ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer, } @staticmethod @@ -826,7 +1058,6 @@ def create( Returns: A single quantizer or tuple of quantizers """ - # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" if n_groups: if n_quantizers != 1: @@ -887,18 +1118,9 @@ def _create_set( if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") - args_x = { - "scale": quantize_meta_set.x.scale, - "amax_history": quantize_meta_set.x.amax_history, - } - args_kernel = { - "scale": quantize_meta_set.kernel.scale, - "amax_history": quantize_meta_set.kernel.amax_history, - } - args_grad = { - "scale": quantize_meta_set.grad.scale, - "amax_history": quantize_meta_set.grad.amax_history, - } + args_x = quantize_meta_set.x.get_kwargs_dictionary() + args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary() + args_grad = quantize_meta_set.grad.get_kwargs_dictionary() else: args_x = args_kernel = args_grad = {} @@ -919,6 +1141,7 @@ def create_set( bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + # TODO(jberchtold): rename fp8_recipe to quantization_recipe fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: @@ -946,21 +1169,24 @@ def create_set( ) if fp8_recipe is not None: - quantize_config = get_quantize_config_class(fp8_recipe)() + quantize_config = get_quantize_config_with_recipe(fp8_recipe) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) - elif scaling_mode is not None: - x_scaling_mode = scaling_mode - kernel_scaling_mode = scaling_mode - grad_scaling_mode = scaling_mode + fwd_dtype = quantize_config.FWD_DTYPE + bwd_dtype = quantize_config.BWD_DTYPE 
else: - x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) - kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) - grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + if scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode + else: + x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) + kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) - fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE - bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE + fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE + bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: # TODO(Jeremy): check x, kernel, grad separately for 2x if x_scaling_mode.is_1d_block_scaling(): diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index b7828e9315..d490e02752 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -100,10 +100,19 @@ def get_scale_dtype(self) -> jnp.dtype: The data type used for scale tensors """ + @abstractmethod + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. 
+ """ + @abstractmethod def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, @@ -112,6 +121,7 @@ def get_scale_shape( Args: data_shape: The shape of the tensor being quantized + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -156,13 +166,15 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode @@ -183,12 +195,22 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NN" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. 
@@ -201,7 +223,14 @@ def get_scale_shape( Returns: The shape for scale tensors - (1,) """ - del data_shape, is_colwise, is_padded, flatten_axis + del ( + data_shape, + data_layout, + is_colwise, + is_padded, + flatten_axis, + broadcast_2d_scale_shape_to_1d, + ) return (0,) @lru_cache(maxsize=4) @@ -239,18 +268,20 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ - del flatten_axis + del flatten_axis, broadcast_2d_scale_shape_to_1d input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -270,25 +301,37 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NT" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors in delayed scaling. 
Args: data_shape: The shape of the tensor being scaled + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True. Returns: The shape for scale tensors - (1,) """ - del is_colwise + del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d if np.prod(data_shape) == 0: return (0,) return (1,) @@ -333,6 +376,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -340,11 +384,12 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ - del flatten_axis + del flatten_axis, broadcast_2d_scale_shape_to_1d input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -368,14 +413,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): _block_alignment: Alignment requirements for blocks """ - def __init__(self, block_dims: Tuple[int]): + def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str): """Initialize block scaling mode implementation. Args: block_dims: Dimensions of the scaling blocks + scale_dtype: Data type of the scale tensor + data_layout: Layout for rowwise and colwise scaling, two characters, e.g. 
"NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
         """
         self._block_dims = block_dims
+        self._scale_dtype = scale_dtype
         self._block_alignment = (128, 4)
+        self._data_layout = data_layout
 
     def get_scale_dtype(self) -> jnp.dtype:
         """Get the data type for scale tensors in block scaling.
@@ -383,7 +432,15 @@ def get_scale_dtype(self) -> jnp.dtype:
         Returns:
             The data type used for scale tensors (float8_e8m0fnu)
         """
-        return jnp.float8_e8m0fnu
+        return self._scale_dtype
+
+    def get_data_layout(self) -> str:
+        """Get the data layout for rowwise and colwise scaling.
+
+        Returns:
+            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
+        """
+        return self._data_layout
 
     def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
         """Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
@@ -411,23 +468,51 @@ def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_
     def get_scale_shape(
         self,
         data_shape: Tuple[int, ...],
+        data_layout: str = "N",
         is_colwise: bool = False,
         is_padded: bool = True,
         flatten_axis: int = -1,
+        broadcast_2d_scale_shape_to_1d: bool = False,
     ) -> Tuple[int, ...]:
         """Get the shape for scale tensors in block scaling.
 
         Args:
             data_shape: The shape of the tensor being quantized
+            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
             is_colwise: Whether the scaling is column-wise
             is_padded: Whether to return padded shape
             flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
+            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: The shape for scale tensors """ + flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape) + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + block_alignment = self._block_alignment if is_padded else (1, 1) + if is_colwise: + assert data_layout == self._data_layout[1], ( + f"Data layout must match colwise layout, received {data_layout} but expected" + f" {self._data_layout[1]}" + ) + else: + assert data_layout == self._data_layout[0], ( + f"Data layout must match rowwise layout, received {data_layout} but expected" + f" {self._data_layout[0]}" + ) + + if is_colwise and self._data_layout[1] == "T": + # TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value + is_colwise = False # now rowwise in T is colwise in N + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis + # flatten_axis is given wrt N layout, convert to T layout + flatten_axis = len(data_shape) - flatten_axis + if is_colwise: block_y, block_x = self._block_dims alignment_y, alignment_x = block_alignment @@ -435,12 +520,7 @@ def get_scale_shape( block_x, block_y = self._block_dims alignment_x, alignment_y = block_alignment - if flatten_axis < 0: - flatten_axis = len(data_shape) + flatten_axis - assert ( - 0 < flatten_axis < len(data_shape) - ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" - + is_block_2d = block_x > 1 and block_y > 1 assert data_shape[flatten_axis - 1] % block_x == 0, ( f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" f" {flatten_axis - 1}" @@ -449,6 +529,9 @@ def get_scale_shape( data_shape[-1] % block_y == 0 ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" + if broadcast_2d_scale_shape_to_1d and is_block_2d: + block_x = 1 + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) 
@@ -575,6 +658,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -582,30 +666,41 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ + # TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed input_rank = len(input_shape) input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank - # This implementation needs to be updated for different block dims. - assert self._block_dims == (1, 32) + assert ( + self._block_dims[1] != 1 + ), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}" + # For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d + if self._block_dims[0] != 1: + assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, ( + f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d}," + f" _block_dims={self._block_dims}" + ) + + block_size_1d = self._block_dims[1] # We have to use two different factors in the two CompoundFactors because of Shardy # verifier requirements, even though they are the same. 
blocksizes = {} colwise_var = f"{unique_var}_None" rowwise_var = f"{unique_var}_None" - if not input_shape[-1] == 32: + if not input_shape[-1] == block_size_1d: rowwise_var = input_spec[-1] + "_compound" input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") - blocksizes["blocksize_x"] = 32 - if not input_shape[flatten_axis - 1] == 32: + blocksizes["blocksize_x"] = block_size_1d + if not input_shape[flatten_axis - 1] == block_size_1d: colwise_var = input_spec[flatten_axis - 1] + "_compound" input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") - blocksizes["blocksize_y"] = 32 + blocksizes["blocksize_y"] = block_size_1d # The rowwise and colwise scale tensors should be sharded the same way as the input. # However, we need to adjust the dimensions where the block scaling factor applies. @@ -632,6 +727,8 @@ class ScalingMode(Enum): - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales + - NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales + - NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales - NO_SCALING: No scaling applied """ @@ -639,6 +736,8 @@ class ScalingMode(Enum): DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING + NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING + NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. 
@@ -662,40 +761,79 @@ def get_scale_dtype(self): """ return self._get_impl().get_scale_dtype() - def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: + def get_scale_shape_2x( + self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False + ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ + data_layout = self._get_impl().get_data_layout() + rowwise_layout = data_layout[0] + assert ( + rowwise_layout == "N" + ), f"For rowwise layout only 'N' is supported, received {rowwise_layout}" + colwise_layout = data_layout[1] + rowwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis + data_shape, + data_layout=rowwise_layout, + is_colwise=False, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) + + colwise_data_shape = data_shape + if colwise_layout == "T": + colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis] colwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis + colwise_data_shape, + data_layout=colwise_layout, + is_colwise=True, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) return (rowwise_scale_shape, colwise_scale_shape) def get_scale_shape( - self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 + self, + data_shape, + data_layout="N", + is_colwise=False, + is_padded=True, + flatten_axis=-1, + broadcast_2d_scale_shape_to_1d=False, ) 
-> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Shape of the data tensor + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: The shape for scale tensors """ - return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + return self._get_impl().get_scale_shape( + data_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, + ) def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: """Get the quantize layout for the tensor usage. @@ -713,6 +851,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis=-1, + broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. @@ -720,11 +859,14 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. 
        Returns:
             The Shardy rules for the scaling mode
         """
-        return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
+        return self._get_impl().get_shardy_sharding_rules(
+            input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d
+        )
 
     def get_grouped_scale_shape_2x(
         self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
@@ -798,8 +940,64 @@ def is_1d_block_scaling(self) -> bool:
         Returns:
             True if the scaling mode is 1D block scaling, False otherwise
         """
+        # Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales are broadcast to 1D because it is required for the GEMM.
+        return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling
+
+    @property
+    def is_block_scaling(self) -> bool:
+        """Check if this scaling mode is block scaling.
+
+        Returns:
+            True if the scaling mode is block scaling, False otherwise
+        """
+        # Currently we only have 1D block scaling modes
+        return self.is_1d_block_scaling()
+
+    def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
+        """Returns a set of compatible quantized data types for this scaling mode.
+
+        Returns:
+            A set of compatible quantized data types
+        """
+        if self in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+            ScalingMode.MXFP8_1D_SCALING,
+        ):
+            return {jnp.float8_e5m2, jnp.float8_e4m3fn}
+        if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
+            return {jnp.float4_e2m1fn}
+        if self == ScalingMode.NO_SCALING:
+            return {jnp.float16, jnp.bfloat16, jnp.float32}
+        raise ValueError(f"Invalid scaling mode: {self}")
+
+    @property
+    def is_nvfp4_scaling(self) -> bool:
+        """Check if this scaling mode is NVFP4 scaling.
+
+        Returns:
+            True if the scaling mode is NVFP4 scaling, False otherwise
+        """
+        return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)
+
+    @property
+    def is_mxfp8_scaling(self) -> bool:
+        """Check if this scaling mode is MXFP8 scaling.
+ + Returns: + True if the scaling mode is NVFP4 scaling, False otherwise + """ return self == ScalingMode.MXFP8_1D_SCALING + @property + def is_colwise_transposed(self) -> bool: + """Check if this scaling mode uses transposed layout for column-wise scaling. + + Returns: + True if the scaling mode uses transposed layout for column-wise scaling, False otherwise + """ + return self.is_tensor_scaling() or self.is_nvfp4_scaling + def __eq__(self, other): """Compare this scaling mode with another. @@ -836,9 +1034,20 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), - # WAR + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(1, 32), + scale_dtype=jnp.float8_e8m0fnu, + data_layout="NN", + ), ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), + ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(1, 16), + scale_dtype=jnp.float8_e4m3fn, + data_layout="NT", + ), + ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT" + ), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index dbbac4abcc..2d2d78190f 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -201,13 +201,32 @@ def __post_init__(self): else: unpadded_scale_shape = self.scaling_mode.get_scale_shape( self.data.shape, + data_layout=self.data_layout, is_colwise=self.is_colwise, is_padded=False, - flatten_axis=self.flatten_axis, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else 
self.data.ndim - self.flatten_axis + ), ) - assert self.scale_inv.shape == unpadded_scale_shape, ( - "Unpadded inverse scale factor has wrong shape, expected" - f" {unpadded_scale_shape} but got {self.scale_inv.shape}." + unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + self.data.shape, + data_layout=self.data_layout, + is_colwise=self.is_colwise, + is_padded=False, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else self.data.ndim - self.flatten_axis + ), + broadcast_2d_scale_shape_to_1d=True, + ) + assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" + f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." ) def tree_flatten(self): @@ -583,6 +602,7 @@ def create_2x( colwise_data, colwise_scale_inv, amax=None, + colwise_amax=None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, data_layout="NN", @@ -612,6 +632,8 @@ def create_2x( """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) + if colwise_amax is None: + colwise_amax = amax assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" rowwise_tensor = ScaledTensorFactory.create_1x( @@ -630,10 +652,10 @@ def create_2x( colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax, scaling_mode, dq_dtype, - is_colwise=True, + is_colwise=True, # TODO(Phuong): set this correctly data_layout=data_layout[1], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -649,6 +671,7 @@ def create( colwise_data: jnp.ndarray, colwise_scale_inv: jnp.ndarray, amax=None, + colwise_amax=None, scaling_mode: ScalingMode = ScalingMode.NO_SCALING, dq_dtype: jnp.dtype = jnp.bfloat16, data_layout: str = "NN", @@ -684,6 +707,7 @@ def create( colwise_data, colwise_scale_inv, amax, + colwise_amax, scaling_mode, dq_dtype, data_layout=data_layout, 
@@ -698,7 +722,7 @@ def create( return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax if colwise_amax is not None else amax, scaling_mode, dq_dtype, is_colwise=is_colwise, From dd9433e7ad28c12f27da9770be54c9c584e85fa0 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Thu, 9 Oct 2025 17:30:35 -0600 Subject: [PATCH 75/78] Don't pickle an empty dict in LayerNorm and pt base modules (#2253) Don't pickle an empty dict in LayerNorm and BasicOperation layers Signed-off-by: Peter St. John Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 103ebf2418..095e3e89e8 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -595,6 +595,9 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: extra[key] = val state[mode]["extra_fp8_variables"] = extra + if not state: + return torch.empty(0, dtype=torch.uint8) + # Serialize state into byte tensor torch.cuda.synchronize() state_serialized = bytearray(pickle.dumps(state)) From 98d354c3d3b2571c7752f85a2fcf97fa6fd2aab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 11 Oct 2025 12:42:07 +0000 Subject: [PATCH 76/78] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 13 ++++++++++--- .../pytorch/module/layernorm_linear.py | 9 +++++++-- transformer_engine/pytorch/module/linear.py | 9 +++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d258504df1..1815a5ef96 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ 
b/transformer_engine/pytorch/module/grouped_linear.py @@ -219,10 +219,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weights[0].requires_grad + and fuse_wgrad_accumulation + ): grad_added_to_main_grad_list = [] for weight in weights: if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): @@ -292,7 +297,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: + if ( + ctx.cpu_offloading or ctx.fine_grained_activation_offloading + ) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): origin_weights[i].main_grad = main_grads[i] origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bba7f991a4..f230fc13a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -445,10 +445,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." 
) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0ac223bf89..9acd5ad1c6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -413,10 +413,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad From 88c7b050d8acd7fb9c363bff66ed64e9cd04a694 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sat, 11 Oct 2025 06:00:18 -0700 Subject: [PATCH 77/78] add comments Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++++- transformer_engine/pytorch/module/layernorm_linear.py | 5 ++++- transformer_engine/pytorch/module/linear.py | 5 ++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d258504df1..e0f43986b8 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -211,6 +211,7 @@ def forward( if isinstance(weight, 
QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) + # Do not offload weights and biases for i in range(num_gemms): weights[i].offloading_activation = False weights_fp8[i].offloading_activation = False @@ -219,9 +220,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: @@ -292,6 +295,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + # Restore the attributes main_grad and grad_added_to_main_grad of weights if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): origin_weights[i].main_grad = main_grads[i] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bba7f991a4..adf616f7a7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -445,9 +445,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." 
) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True @@ -593,6 +595,7 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. + # Restore the attributes grad_added_to_main_grad of weights if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0ac223bf89..7180636ba4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -413,9 +413,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." 
) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True @@ -529,6 +531,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) + # Restore the attributes main_grad and grad_added_to_main_grad of weights if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object From fe9bab470cc5fd1a6f575301df425a6f6862d258 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Tue, 21 Oct 2025 01:39:39 -0700 Subject: [PATCH 78/78] remove unused code Signed-off-by: hongbinl --- transformer_engine/pytorch/module/grouped_linear.py | 5 ----- transformer_engine/pytorch/module/layernorm_linear.py | 6 ------ transformer_engine/pytorch/module/linear.py | 5 ----- 3 files changed, 16 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 6b4c35170c..7b15ebf527 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -211,11 +211,6 @@ def forward( if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) - # Do not offload weights and biases - for i in range(num_gemms): - weights[i].offloading_activation = False - weights_fp8[i].offloading_activation = False - biases[i].offloading_activation = False ctx.fine_grained_activation_offloading = fine_grained_activation_offloading if fine_grained_activation_offloading and cpu_offloading: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 75c8a030d9..57f4e25eba 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ 
b/transformer_engine/pytorch/module/layernorm_linear.py @@ -435,12 +435,6 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - # Do not offload weights and biases - weight.offloading_activation = False - weightmat.offloading_activation = False - if bias is not None: - bias.offloading_activation = False - ln_weight.offloading_activation = False ctx.fine_grained_activation_offloading = fine_grained_activation_offloading if fine_grained_activation_offloading and cpu_offloading: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ffbe02ae98..7b974e4a7f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -442,11 +442,6 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - # Do not offload weights and biases - weight.offloading_activation = False - weightmat.offloading_activation = False - if bias is not None: - bias.offloading_activation = False # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat,