From 96db8f55b6283393517fa5df19092f88e8eaf911 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 12 Mar 2025 02:55:41 -0700 Subject: [PATCH 01/53] split wgrad for GroupedLinear Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 32 +++++++++++---- transformer_engine/pytorch/module/_common.py | 39 ++++++++++++++++++ .../pytorch/module/grouped_linear.py | 41 +++++++++++++------ 3 files changed, 92 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..feb4b6b5d0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -106,6 +106,13 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.Float8CurrentScaling(), ] +# param_types = [torch.bfloat16] +# batch_sizes = [2] +# all_boolean = [False] +# mask_types = ["causal"] +# fp8_recipes = [ +# recipe.Float8CurrentScaling(), +# ] def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -1445,7 +1452,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret def _test_grouped_linear_accuracy( - block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw=False ): reset_rng_states() if fp8: @@ -1488,10 +1495,15 @@ def _test_grouped_linear_accuracy( ) loss = out.sum() loss.backward() + if split_bw: + block.wgrad_comp() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] - for p in block.parameters(): + # breakpoint() + for name, p in block.named_parameters(): + # print(f"参数名称: {name}, 参数形状: {p.shape}", {p.main_grad}) + # breakpoint() if p.requires_grad: if getattr(p, "main_grad", None) is not None: outputs.append(p.main_grad) @@ -1508,7 +1520,9 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("split_bw", all_boolean) def test_grouped_linear_accuracy( dtype, num_gemms, @@ -1518,6 +1532,8 @@ def test_grouped_linear_accuracy( recipe, fp8_model_params, fuse_wgrad_accumulation, + bias, + split_bw, parallel_mode=None, ): if fp8 and not fp8_available: @@ -1538,18 +1554,19 @@ def test_grouped_linear_accuracy( num_gemms, config.hidden_size, 4 * config.hidden_size, - bias=True, + bias=bias, params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", fuse_wgrad_accumulation=fuse_wgrad_accumulation, + split_bw=split_bw, ).eval() sequential_linear = torch.nn.ModuleList( [ Linear( config.hidden_size, 4 * config.hidden_size, - bias=True, + bias=bias, params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", @@ -1563,7 +1580,8 @@ def test_grouped_linear_accuracy( with torch.no_grad(): for i in range(num_gemms): sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -1573,7 +1591,7 @@ def test_grouped_linear_accuracy( sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation ) outputs = _test_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw ) # Shoule be bit-wise match diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index c2b525ab55..467964a988 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -10,6 +10,7 @@ from operator import mul as multiply_op import torch +import queue from .. import cpp_extensions as tex from ..constants import TE_DType @@ -216,3 +217,41 @@ def __post_init__(self): """Safeguard reference to the parameter's parent module and initialization function.""" if self.init_fn is None: self.init_fn = get_default_init_method() + + +class WeightGradStore: + + def __init__(self, split_bw = False): + self.context = queue.Queue() + self.enabled = split_bw + + def is_supported(self): + # doesn't support use_bias + return True + + def split_bw(self): + if not self.is_supported(): + return False + return self.enabled + + def enable_split_bw(self): + self.enabled = True + + def disable_split_bw(self): + self.enabled = False + + def put(self, tensor_list, func): + self.context.put([tensor_list, func]) + return + + def pop(self): + if self.context.qsize() > 0: + tensor_list, func = self.context.get() + func(*tensor_list) + else: + rank = torch.distributed.get_rank() + raise Exception(f"Pop empty queue. rank {rank}") + + def assert_empty(self): + rank = torch.distributed.get_rank() + assert self.context.empty(), f"Queue is not empty. rank {rank}" \ No newline at end of file diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..44d124521c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -6,6 +6,7 @@ from typing import Union, Optional, Callable, Tuple, List import torch +import functools import transformer_engine_torch as tex @@ -16,6 +17,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ._common import WeightGradStore from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, @@ -65,6 +67,7 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, input_quantizers: List[Quantizer], weight_quantizers: List[Quantizer], output_quantizers: List[Quantizer], @@ -213,6 +216,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() ) + ctx.wgrad_store = wgrad_store # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -294,13 +298,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - # WGRAD - _, grad_biases_, _ = general_grouped_gemm( - inputmats, - grad_output, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), + grouped_gemm_wgrad = functools.partial(general_grouped_gemm, + out_dtype=ctx.activation_dtype, + workspaces=get_multi_stream_cublas_workspace(), layout="NT", grad=True, m_splits=ctx.m_splits, @@ -309,13 +309,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_wgrad_into_param_main_grad, ) - for i in range(ctx.num_gemms): - if grad_biases[i] is None: - grad_biases[i] = grad_biases_[i] - del grad_biases_ + # WGRAD + if ctx.wgrad_store.split_bw(): + # if True: + ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + else: + _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) + + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ - # Deallocate input tensor - clear_tensor_data(*inputmats) + # Deallocate input tensor + clear_tensor_data(*inputmats) def handle_custom_ddp_from_mcore(weight, wgrad): if ctx.weights_requires_grad: @@ -372,6 +379,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, None, # is_grad_enabled None, # is_grad_enabled *wgrad_list, @@ -445,6 +453,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, + split_bw: bool = False, ) -> None: super().__init__() @@ -465,6 +474,8 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.wgrad_store = WeightGradStore(split_bw) + self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} if tp_group is None: @@ -647,6 +658,7 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, input_quantizers, weight_quantizers, output_quantizers, @@ -677,3 +689,6 @@ def forward( if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out + def wgrad_comp(self): + with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + self.wgrad_store.pop() \ No newline at end of file From 4d3326e4d478624f2467410b6729823c5d295768 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Mar 2025 10:02:58 +0000 Subject: [PATCH 02/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 1 + transformer_engine/pytorch/module/_common.py | 8 ++++---- transformer_engine/pytorch/module/grouped_linear.py | 8 +++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index feb4b6b5d0..b8cd29bf42 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -114,6 +114,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq # recipe.Float8CurrentScaling(), # ] + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 467964a988..77277b8cae 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -221,7 +221,7 @@ def __post_init__(self): class WeightGradStore: - def __init__(self, split_bw = False): + def __init__(self, split_bw=False): self.context = queue.Queue() self.enabled = split_bw @@ -238,8 +238,8 @@ def enable_split_bw(self): self.enabled = True def disable_split_bw(self): - self.enabled = False - + self.enabled = False + def put(self, tensor_list, func): self.context.put([tensor_list, func]) return @@ -254,4 +254,4 @@ def pop(self): def assert_empty(self): rank = torch.distributed.get_rank() - assert self.context.empty(), f"Queue is not empty. rank {rank}" \ No newline at end of file + assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 44d124521c..370014c652 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -298,7 +298,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - grouped_gemm_wgrad = functools.partial(general_grouped_gemm, + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, out_dtype=ctx.activation_dtype, workspaces=get_multi_stream_cublas_workspace(), layout="NT", @@ -311,7 +312,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) # WGRAD if ctx.wgrad_store.split_bw(): - # if True: + # if True: ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) @@ -689,6 +690,7 @@ def forward( if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out + def wgrad_comp(self): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): - self.wgrad_store.pop() \ No newline at end of file + self.wgrad_store.pop() From 94f18925934b488ad708e8dcce51adbc7cc63bde Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 12 Mar 2025 06:31:29 -0700 Subject: [PATCH 03/53] support wgrad split for linear and ln_linear Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 126 +++++++++++++++++- transformer_engine/pytorch/module/_common.py | 2 + .../pytorch/module/grouped_linear.py | 2 + .../pytorch/module/layernorm_linear.py | 56 +++++--- transformer_engine/pytorch/module/linear.py | 58 ++++++-- 5 files changed, 205 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b8cd29bf42..ab1c02c90c 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1036,7 +1036,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_granular_accuracy(block, bs, dtype, config): +def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): reset_rng_states() inp_hidden_states = torch.randn( @@ -1052,12 +1052,20 @@ def _test_granular_accuracy(block, bs, dtype, config): out = out[0] loss = out.sum() loss.backward() + if split_bw and hasattr(block, 'wgrad_comp'): + block.wgrad_comp() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] - for p in block.parameters(): + for name, p in block.named_parameters(): + # print(f"参数名称: {name}, 参数形状: {p.shape}", {p.main_grad}) + # breakpoint() if p.requires_grad: - outputs.append(p.grad) + if getattr(p, "main_grad", None) is not None: + outputs.append(p.main_grad) + assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True + else: + outputs.append(p.grad) return outputs @@ -1190,6 +1198,50 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): for te_output, torch_output in zip(te_outputs, torch_outputs): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulation): + config = model_configs[model] + + te_linear_ref = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + device="cuda", + split_bw=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + te_linear = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + device="cuda", + split_bw=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + # Share params + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if bias: + te_linear_ref.bias = Parameter(te_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, split_bw=True) + te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, split_bw=False) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @@ -1375,6 +1427,62 @@ def test_layernorm_linear_accuracy( for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("normalization", all_normalizations) +@pytest.mark.parametrize("zero_centered_gamma", all_boolean) +@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +def test_layernorm_linear_accuracy_split_bw(dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation): + config = model_configs[model] + + ln_linear_ref = LayerNormLinear( + config.hidden_size, + 4 * config.hidden_size, + config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + split_bw=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + ln_linear = LayerNormLinear( + config.hidden_size, + 4 * config.hidden_size, + config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + split_bw=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + + # Share params + with torch.no_grad(): + ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) + ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) + if bias: + ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(ln_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + ln_linear_ref.weight.main_grad = weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, split_bw=True) + te_outputs_ref = _test_granular_accuracy(ln_linear_ref, bs, dtype, config, split_bw=False) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @@ -1497,7 +1605,11 @@ def _test_grouped_linear_accuracy( loss = out.sum() loss.backward() if split_bw: - block.wgrad_comp() + if isinstance(block, GroupedLinear): + block.wgrad_comp() + else: + for i in range(num_gemms): + block[i].wgrad_comp() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] @@ -1589,7 +1701,7 @@ def test_grouped_linear_accuracy( sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw ) outputs = _test_grouped_linear_accuracy( grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw @@ -1614,6 +1726,8 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): fp8_model_params=True, parallel_mode=parallel_mode, fuse_wgrad_accumulation=True, + bias=True, + split_bw=False, ) @@ -1629,6 +1743,8 @@ def test_grouped_linear_accuracy_single_gemm(recipe): recipe=recipe, fp8_model_params=True, fuse_wgrad_accumulation=True, + bias=True, + split_bw=False, ) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 77277b8cae..f2e5dd86ad 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -227,6 +227,8 @@ def __init__(self, split_bw=False): def is_supported(self): # doesn't support use_bias + # doesn't support fuse_wgrad_accumulation + # doesn't support return True def split_bw(self): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 370014c652..4a6242cfff 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -692,5 +692,7 @@ def forward( return out def wgrad_comp(self): + if not self.wgrad_store.split_bw(): + return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..ba061cfd77 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -11,6 +11,7 @@ import torch from torch.nn import init +import functools import transformer_engine_torch as tex @@ -48,7 +49,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose +from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -83,6 +84,7 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, fuse_wgrad_accumulation: bool, input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], @@ -418,6 +420,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store # Row Parallel Linear if ub_overlap_rs_fprop: @@ -713,15 +716,11 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - wgrad, grad_bias_, *_, rs_out = general_gemm( - ln_out_total, - grad_output, - get_workspace(), + general_gemm_wgrad = functools.partial(general_gemm, + out_dtype=main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + workspace=get_workspace(), layout="NT", grad=True, - out_dtype=( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype - ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, @@ -729,24 +728,29 @@ def backward( ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, - bulk_overlap=ctx.ub_bulk_wgrad, + bulk_overlap=ctx.ub_bulk_wgrad, ) - nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + if ctx.wgrad_store.split_bw(): + ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad) + else: + wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if not ctx.return_layernorm_output: + # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme + clear_tensor_data(ln_out_total) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = rs_out else: - dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) - - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate input tensor - if not ctx.return_layernorm_output: - # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme - clear_tensor_data(ln_out_total) + dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) # Make sure all tensor-parallel communication is finished if ln_out_total_work is not None: @@ -833,6 +837,7 @@ def backward( None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # fuse_wgrad_accumulation None, # input_quantizer None, # weight_quantizer @@ -979,6 +984,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + split_bw: bool = False, ) -> None: super().__init__() @@ -995,6 +1001,8 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.wgrad_store = WeightGradStore(split_bw) + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1344,6 +1352,7 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, self.fuse_wgrad_accumulation, input_quantizer, weight_quantizer, @@ -1454,3 +1463,10 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + + def wgrad_comp(self): + # return + if not self.wgrad_store.split_bw(): + return + with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..025a2e055d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -8,6 +8,7 @@ from operator import mul as multiply_op import torch +import functools import transformer_engine_torch as tex @@ -20,7 +21,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, _fix_gathered_fp8_transpose +from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -79,6 +80,7 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], @@ -354,6 +356,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store # Row Parallel Linear if ub_overlap_rs_fprop: @@ -633,15 +636,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") +<<<<<<< HEAD wgrad, grad_bias_, _, rs_out = general_gemm( inputmat_total, grad_output, get_workspace(), +======= + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + + general_gemm_wgrad = functools.partial(general_gemm, + out_dtype=main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + workspace=get_workspace(), +>>>>>>> 267b38e (support wgrad split for linear and ln_linear) layout="NT", grad=True, - out_dtype=( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype - ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, @@ -649,23 +663,28 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, - bulk_overlap=ctx.ub_bulk_wgrad, + bulk_overlap=ctx.ub_bulk_wgrad, ) + + if ctx.wgrad_store.split_bw(): + ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad) + else: + wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if ctx.owns_input: + clear_tensor_data(inputmat_total) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = rs_out else: - dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) - - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate input tensor - if ctx.owns_input: - clear_tensor_data(inputmat_total) + dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) # Don't return grad bias if not needed if not ctx.use_bias: @@ -721,6 +740,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # input_quantizer None, # weight_quantizer None, # output_quantizer @@ -842,6 +862,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + split_bw: bool = False, ) -> None: super().__init__() @@ -855,6 +876,8 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.wgrad_store = WeightGradStore(split_bw) + if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." if tp_group is None: @@ -1159,6 +1182,7 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, input_quantizer, weight_quantizer, output_quantizer, @@ -1264,3 +1288,9 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + + def wgrad_comp(self): + if not self.wgrad_store.split_bw(): + return + with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + self.wgrad_store.pop() From 4cac7d0df130f713dfea766882a3ecabb6b60f11 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 2 Apr 2025 23:58:34 -0700 Subject: [PATCH 04/53] add comments and fix WeightGradStore Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/_common.py | 45 +++++++++++++++---- .../pytorch/module/grouped_linear.py | 6 ++- .../pytorch/module/layernorm_linear.py | 9 ++-- transformer_engine/pytorch/module/linear.py | 8 +++- 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index f2e5dd86ad..a49dcba046 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -220,33 +220,56 @@ def __post_init__(self): class WeightGradStore: + """ + A class to manage weight gradient storage and computation in Transformer modules. + This class enables split backward propagation for better memory efficiency. + """ - def __init__(self, split_bw=False): + def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True): + """ + Initialize the WeightGradStore. + + Args: + split_bw (bool): Whether to enable split backward propagation + """ self.context = queue.Queue() + assert use_bias == False, "use_bias is not supported when enable split_bw" + assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enable split_bw" self.enabled = split_bw - def is_supported(self): - # doesn't support use_bias - # doesn't support fuse_wgrad_accumulation - # doesn't support - return True - def split_bw(self): - if not self.is_supported(): - return False + """ + Get the current split backward propagation status. + + Returns: + bool: True if split backward is enabled, False otherwise + """ return self.enabled def enable_split_bw(self): + """Enable split backward propagation.""" self.enabled = True def disable_split_bw(self): + """Disable split backward propagation.""" self.enabled = False def put(self, tensor_list, func): + """ + Store tensors and computation function for later execution. + + Args: + tensor_list (list): List of tensors needed for computation + func (callable): Function to be executed with the tensors + """ self.context.put([tensor_list, func]) return def pop(self): + """ + Execute the stored computation with the stored tensors. + Raises an exception if the queue is empty. + """ if self.context.qsize() > 0: tensor_list, func = self.context.get() func(*tensor_list) @@ -255,5 +278,9 @@ def pop(self): raise Exception(f"Pop empty queue. rank {rank}") def assert_empty(self): + """ + Assert that the queue is empty. + Used for debugging and ensuring proper cleanup. + """ rank = torch.distributed.get_rank() assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4a6242cfff..d15e925ce2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -475,7 +475,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw) + self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} @@ -692,6 +692,10 @@ def forward( return out def wgrad_comp(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ba061cfd77..ff6f1ab4e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1001,7 +1001,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.wgrad_store = WeightGradStore(split_bw) + self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) if tp_group is None: self.tp_size = tp_size @@ -1465,8 +1465,11 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon def wgrad_comp(self): - # return + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ if not self.wgrad_store.split_bw(): return - with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 025a2e055d..6c19b1e447 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -876,7 +876,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw) + self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1290,7 +1290,11 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe ].amax_reduction_group = self.tp_group def wgrad_comp(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ if not self.wgrad_store.split_bw(): return - with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): + with torch.cuda.nvtx.range("_Linear_wgrad"): self.wgrad_store.pop() From 981ed830ab28ffc5e32956ce5f83c29cd69261da Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 3 Apr 2025 03:07:53 -0700 Subject: [PATCH 05/53] support bias and fix unit tests Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 22 ++++----------- transformer_engine/pytorch/module/_common.py | 8 +++--- .../pytorch/module/grouped_linear.py | 11 +++++--- .../pytorch/module/layernorm_linear.py | 16 ++++++++--- transformer_engine/pytorch/module/linear.py | 27 +++++++------------ 5 files changed, 39 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index ab1c02c90c..754d4747c0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -106,13 +106,6 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.Float8CurrentScaling(), ] -# param_types = [torch.bfloat16] -# batch_sizes = [2] -# all_boolean = [False] -# mask_types = ["causal"] -# fp8_recipes = [ -# recipe.Float8CurrentScaling(), -# ] def get_causal_attn_mask(sq: int) -> torch.Tensor: @@ -1057,9 +1050,7 @@ def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] - for name, p in block.named_parameters(): - # print(f"参数名称: {name}, 参数形状: {p.shape}", {p.main_grad}) - # breakpoint() + for p in block.parameters(): if p.requires_grad: if getattr(p, "main_grad", None) is not None: outputs.append(p.main_grad) @@ -1201,7 +1192,7 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulation): config = model_configs[model] @@ -1432,7 +1423,7 @@ def test_layernorm_linear_accuracy( @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) -@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) def test_layernorm_linear_accuracy_split_bw(dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation): config = model_configs[model] @@ -1613,10 +1604,7 @@ def _test_grouped_linear_accuracy( torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] - # breakpoint() - for name, p in block.named_parameters(): - # print(f"参数名称: {name}, 参数形状: {p.shape}", {p.main_grad}) - # breakpoint() + for p in block.parameters(): if p.requires_grad: if getattr(p, "main_grad", None) is not None: outputs.append(p.main_grad) @@ -1634,7 +1622,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) -@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("split_bw", all_boolean) def test_grouped_linear_accuracy( dtype, diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index a49dcba046..c282733609 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -225,7 +225,7 @@ class WeightGradStore: This class enables split backward propagation for better memory efficiency. """ - def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True): + def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, ub_bulk_wgrad=False): """ Initialize the WeightGradStore. @@ -233,8 +233,8 @@ def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True) split_bw (bool): Whether to enable split backward propagation """ self.context = queue.Queue() - assert use_bias == False, "use_bias is not supported when enable split_bw" - assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enable split_bw" + assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enabling split_bw" + assert ub_bulk_wgrad == False, "ub_bulk_wgrad is not supported when enabling split_bw" self.enabled = split_bw def split_bw(self): @@ -272,7 +272,7 @@ def pop(self): """ if self.context.qsize() > 0: tensor_list, func = self.context.get() - func(*tensor_list) + return func(*tensor_list) else: rank = torch.distributed.get_rank() raise Exception(f"Pop empty queue. rank {rank}") diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d15e925ce2..f3f79025fa 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -312,7 +312,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) # WGRAD if ctx.wgrad_store.split_bw(): - # if True: ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) @@ -359,7 +358,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias: + if not ctx.use_bias or (ctx.wgrad_store.split_bw() and not ctx.fp8): grad_biases = [None] * ctx.num_gemms if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): @@ -699,4 +698,10 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): - self.wgrad_store.pop() + _, grad_biases_, _ = self.wgrad_store.pop() + if self.use_bias: + for i in range(self.num_gemms): + bias_param = getattr(self, f"bias{i}") + if bias_param.grad is None: + bias_param.grad = grad_biases_[i].to(bias_param.dtype) + del grad_biases_ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ff6f1ab4e3..432a297dfc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -752,7 +752,11 @@ def backward( else: dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) - # Make sure all tensor-parallel communication is finished + # Don't return grad bias if not needed + if not ctx.use_bias or ctx.wgrad_store.split_bw(): + grad_bias = None + + # Synchronize tensor parallel communication if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None @@ -1001,7 +1005,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) + self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation, ub_bulk_wgrad) if tp_group is None: self.tp_size = tp_size @@ -1472,4 +1476,10 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): - self.wgrad_store.pop() + _, grad_bias_, _, _ = self.wgrad_store.pop() + if self.use_bias: + for bias_name in self.bias_names: + bias_param = getattr(self, bias_name) + if bias_param.grad is None: + bias_param.grad = grad_bias_.to(bias_param.dtype) + del grad_bias_ diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c19b1e447..dc8edab3f3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -636,24 +636,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") -<<<<<<< HEAD - wgrad, grad_bias_, _, rs_out = general_gemm( - inputmat_total, - grad_output, - get_workspace(), -======= - wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - wgrad_gemm_use_split_accumulator = ( - recipe.fp8_gemm_wgrad.use_split_accumulator - ) - general_gemm_wgrad = functools.partial(general_gemm, out_dtype=main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, workspace=get_workspace(), ->>>>>>> 267b38e (support wgrad split for linear and ln_linear) layout="NT", grad=True, bias=(bias if (grad_bias is None and not ctx.fp8) else None), @@ -687,7 +672,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) # Don't return grad bias if not needed - if not ctx.use_bias: + if not ctx.use_bias or ctx.wgrad_store.split_bw(): grad_bias = None # Make sure all tensor-parallel communication is finished @@ -876,7 +861,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) + self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation, ub_bulk_wgrad) if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1297,4 +1282,10 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_Linear_wgrad"): - self.wgrad_store.pop() + _, grad_bias_, _, _ = self.wgrad_store.pop() + if self.use_bias: + for bias_name in self.bias_names: + bias_param = getattr(self, bias_name) + if bias_param.grad is None: + bias_param.grad = grad_bias_.to(bias_param.dtype) + del grad_bias_ From d5f83768fde4b533dc928581de5aae0d162fd475 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 6 Apr 2025 19:36:28 -0700 Subject: [PATCH 06/53] minor fix Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/_common.py | 15 +++++++++++---- transformer_engine/pytorch/module/linear.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index c282733609..50c7efe9ea 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -232,10 +232,14 @@ def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, Args: split_bw (bool): Whether to enable split backward propagation """ - self.context = queue.Queue() - assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enabling split_bw" - assert ub_bulk_wgrad == False, "ub_bulk_wgrad is not supported when enabling split_bw" - self.enabled = split_bw + if split_bw: + self.context = queue.Queue() + assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enabling split_bw" + assert ub_bulk_wgrad == False, "ub_bulk_wgrad is not supported when enabling split_bw" + self.enabled = split_bw + else: + self.context = None + self.enabled = False def split_bw(self): """ @@ -262,6 +266,7 @@ def put(self, tensor_list, func): tensor_list (list): List of tensors needed for computation func (callable): Function to be executed with the tensors """ + assert self.enabled == True, "split_bw is not enabled" self.context.put([tensor_list, func]) return @@ -270,6 +275,7 @@ def pop(self): Execute the stored computation with the stored tensors. Raises an exception if the queue is empty. """ + assert self.enabled == True, "split_bw is not enabled" if self.context.qsize() > 0: tensor_list, func = self.context.get() return func(*tensor_list) @@ -282,5 +288,6 @@ def assert_empty(self): Assert that the queue is empty. Used for debugging and ensuring proper cleanup. """ + assert self.enabled == True, "split_bw is not enabled" rank = torch.distributed.get_rank() assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dc8edab3f3..3115638c3f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1273,7 +1273,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - + def wgrad_comp(self): """ Execute the delayed weight gradient computation. From 6c23454f32b29edfe6119df6ae5c16afe75d17e7 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 6 Apr 2025 20:57:51 -0700 Subject: [PATCH 07/53] support fuse_grad_accumulation=false Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 10 ++++++---- transformer_engine/pytorch/module/_common.py | 3 +-- .../pytorch/module/grouped_linear.py | 13 +++++++++++-- .../pytorch/module/layernorm_linear.py | 17 +++++++++++------ transformer_engine/pytorch/module/linear.py | 17 +++++++++++------ 5 files changed, 40 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 754d4747c0..ee95a73ba9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1193,7 +1193,7 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulation): config = model_configs[model] @@ -1424,8 +1424,10 @@ def test_layernorm_linear_accuracy( @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) -def test_layernorm_linear_accuracy_split_bw(dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation): +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +def test_layernorm_linear_accuracy_split_bw( + dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation +): config = model_configs[model] ln_linear_ref = LayerNormLinear( @@ -1621,7 +1623,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("split_bw", all_boolean) def test_grouped_linear_accuracy( diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 50c7efe9ea..e30621fd67 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -234,7 +234,6 @@ def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, """ if split_bw: self.context = queue.Queue() - assert fuse_wgrad_accumulation == True, "fuse_wgrad_accumulation is not supported when enabling split_bw" assert ub_bulk_wgrad == False, "ub_bulk_wgrad is not supported when enabling split_bw" self.enabled = split_bw else: @@ -278,7 +277,7 @@ def pop(self): assert self.enabled == True, "split_bw is not enabled" if self.context.qsize() > 0: tensor_list, func = self.context.get() - return func(*tensor_list) + return func(*tensor_list), tensor_list else: rank = torch.distributed.get_rank() raise Exception(f"Pop empty queue. rank {rank}") diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3f79025fa..cc2a1705e1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -357,6 +357,9 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ] else: wgrad_list = [None] * ctx.num_gemms + + if ctx.wgrad_store.split_bw(): + wgrad_list = [None] * ctx.num_gemms if not ctx.use_bias or (ctx.wgrad_store.split_bw() and not ctx.fp8): grad_biases = [None] * ctx.num_gemms @@ -474,7 +477,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation) + self.wgrad_store = WeightGradStore(split_bw) self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} @@ -698,7 +701,13 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): - _, grad_biases_, _ = self.wgrad_store.pop() + (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() + wgrad_list = tensor_list[2] + if not self.fuse_wgrad_accumulation: + for i in range(self.num_gemms): + weight_param = getattr(self, f"weight{i}") + if weight_param.grad is None: + weight_param.grad = wgrad_list[i].to(weight_param.dtype) if self.use_bias: for i in range(self.num_gemms): bias_param = getattr(self, f"bias{i}") diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 432a297dfc..64bde8217f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1005,7 +1005,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation, ub_bulk_wgrad) + self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) if tp_group is None: self.tp_size = tp_size @@ -1476,10 +1476,15 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): - _, grad_bias_, _, _ = self.wgrad_store.pop() + (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + unfused_weights = [getattr(self, name) for name in self.weight_names] + weight_tensor = noop_cat(unfused_weights) + if weight_tensor.grad is None: + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: - for bias_name in self.bias_names: - bias_param = getattr(self, bias_name) - if bias_param.grad is None: - bias_param.grad = grad_bias_.to(bias_param.dtype) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) del grad_bias_ + del wgrad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3115638c3f..146d827cfa 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -861,7 +861,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw, bias, fuse_wgrad_accumulation, ub_bulk_wgrad) + self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1282,10 +1282,15 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_Linear_wgrad"): - _, grad_bias_, _, _ = self.wgrad_store.pop() + (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + unfused_weights = [getattr(self, name) for name in self.weight_names] + weight_tensor = noop_cat(unfused_weights) + if weight_tensor.grad is None: + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: - for bias_name in self.bias_names: - bias_param = getattr(self, bias_name) - if bias_param.grad is None: - bias_param.grad = grad_bias_.to(bias_param.dtype) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) del grad_bias_ + del wgrad From cb49c707f72bbc3624facc8782ac076c4a5c948e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 02:12:52 +0000 Subject: [PATCH 08/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 26 +++++++++++++------ transformer_engine/pytorch/module/_common.py | 10 ++++--- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 13 ++++++---- transformer_engine/pytorch/module/linear.py | 11 +++++--- 5 files changed, 40 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index ee95a73ba9..946f0c74d9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -107,7 +107,6 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq ] - def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -1045,7 +1044,7 @@ def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): out = out[0] loss = out.sum() loss.backward() - if split_bw and hasattr(block, 'wgrad_comp'): + if split_bw and hasattr(block, "wgrad_comp"): block.wgrad_comp() torch.cuda.synchronize() @@ -1056,7 +1055,7 @@ def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): outputs.append(p.main_grad) assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True else: - outputs.append(p.grad) + outputs.append(p.grad) return outputs @@ -1189,6 +1188,7 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): for te_output, torch_output in zip(te_outputs, torch_outputs): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @@ -1225,7 +1225,7 @@ def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulatio if fuse_wgrad_accumulation: weight = getattr(te_linear, f"weight") weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, split_bw=True) te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, split_bw=False) @@ -1234,6 +1234,7 @@ def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulatio for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @@ -1418,6 +1419,7 @@ def test_layernorm_linear_accuracy( for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @@ -1454,8 +1456,7 @@ def test_layernorm_linear_accuracy_split_bw( device="cuda", split_bw=True, fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - + ).eval() # Share params with torch.no_grad(): @@ -1468,7 +1469,7 @@ def test_layernorm_linear_accuracy_split_bw( if fuse_wgrad_accumulation: weight = getattr(ln_linear, f"weight") weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - ln_linear_ref.weight.main_grad = weight.main_grad.clone() + ln_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, split_bw=True) te_outputs_ref = _test_granular_accuracy(ln_linear_ref, bs, dtype, config, split_bw=False) @@ -1477,6 +1478,7 @@ def test_layernorm_linear_accuracy_split_bw( for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @@ -1691,7 +1693,15 @@ def test_grouped_linear_accuracy( sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + split_bw, ) outputs = _test_grouped_linear_accuracy( grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index e30621fd67..11d2e40b96 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -225,10 +225,12 @@ class WeightGradStore: This class enables split backward propagation for better memory efficiency. """ - def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, ub_bulk_wgrad=False): + def __init__( + self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, ub_bulk_wgrad=False + ): """ Initialize the WeightGradStore. - + Args: split_bw (bool): Whether to enable split backward propagation """ @@ -243,7 +245,7 @@ def __init__(self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, def split_bw(self): """ Get the current split backward propagation status. - + Returns: bool: True if split backward is enabled, False otherwise """ @@ -260,7 +262,7 @@ def disable_split_bw(self): def put(self, tensor_list, func): """ Store tensors and computation function for later execution. - + Args: tensor_list (list): List of tensors needed for computation func (callable): Function to be executed with the tensors diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cc2a1705e1..265958a884 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -357,7 +357,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ] else: wgrad_list = [None] * ctx.num_gemms - + if ctx.wgrad_store.split_bw(): wgrad_list = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 64bde8217f..20d8da8442 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -716,8 +716,11 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - general_gemm_wgrad = functools.partial(general_gemm, - out_dtype=main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + general_gemm_wgrad = functools.partial( + general_gemm, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), workspace=get_workspace(), layout="NT", grad=True, @@ -728,7 +731,7 @@ def backward( ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, - bulk_overlap=ctx.ub_bulk_wgrad, + bulk_overlap=ctx.ub_bulk_wgrad, ) if ctx.wgrad_store.split_bw(): @@ -745,12 +748,12 @@ def backward( # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme clear_tensor_data(ln_out_total) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") - + if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = rs_out else: - dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) + dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) # Don't return grad bias if not needed if not ctx.use_bias or ctx.wgrad_store.split_bw(): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 146d827cfa..cf9f9b17e4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -636,8 +636,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - general_gemm_wgrad = functools.partial(general_gemm, - out_dtype=main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + general_gemm_wgrad = functools.partial( + general_gemm, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), workspace=get_workspace(), layout="NT", grad=True, @@ -648,7 +651,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, - bulk_overlap=ctx.ub_bulk_wgrad, + bulk_overlap=ctx.ub_bulk_wgrad, ) if ctx.wgrad_store.split_bw(): @@ -669,7 +672,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ub_obj_wgrad.is_fp8_ubuf(): dgrad = rs_out else: - dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) + dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) # Don't return grad bias if not needed if not ctx.use_bias or ctx.wgrad_store.split_bw(): From ba5dc5ddebec040939d52ba72efb70202e6fa25d Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Tue, 8 Apr 2025 06:36:11 -0700 Subject: [PATCH 09/53] Enable reuse of dummy wgrad tensor (#1651) * Use dummy wgrads for lower memory consumption Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Vasudevan Rengasamy * Bug fix to avoid sharing gradients. Signed-off-by: Vasudevan Rengasamy * Disable automatic use of batch_p2p_comm for CP2 Signed-off-by: Vasudevan Rengasamy * Change weight to origin_weight for LN_LINEAR Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Vasudevan Rengasamy Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 4 ++-- transformer_engine/pytorch/module/base.py | 17 +++++++++++++++++ .../pytorch/module/layernorm_linear.py | 18 ++++++++---------- transformer_engine/pytorch/module/linear.py | 18 ++++++++---------- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6440c628cd..0d442435bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -616,7 +616,7 @@ def forward( rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -1564,7 +1564,7 @@ def backward(ctx, dout): rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..31a464caad 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -43,6 +43,7 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _multi_stream_cublas_workspace = [] +_dummy_wgrads = {} _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 @@ -78,6 +79,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: return _multi_stream_cublas_workspace +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + + def initialize_ub( shape: list, tp_size: int, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..f49bad48c3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -19,6 +19,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -796,18 +797,15 @@ def backward( if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..ca9dd29043 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -16,6 +16,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -688,18 +689,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None From 9d4e11eaa508383e35b510dc338e58b09c30be73 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 8 Apr 2025 14:12:05 -0700 Subject: [PATCH 10/53] [PyTorch] Debug GEMM refactor (#1652) * Minor stylistic tweaks and typo fixes Review suggestions from @ptrendx Signed-off-by: Tim Moon * Fix incorrect col strides for MXFP8 matrices Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/gemm/cublaslt_gemm.cu | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6fe3539257..0cd0762ee5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const int arch = cuda::sm_arch(); // Transpose mode with column-major ordering - bool transa_bool = transA == CUBLAS_OP_T; - bool transb_bool = transB == CUBLAS_OP_T; + bool is_A_transposed = transA == CUBLAS_OP_T; + bool is_B_transposed = transB == CUBLAS_OP_T; // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = transA; ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - ret.lda = transa_bool ? k : m; - if (arch < 100 && !transa_bool) { + ret.lda = is_A_transposed ? k : m; + if (arch < 100 && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = transA; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.lda = m; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? k : m; } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage @@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transB = transB; ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - ret.ldb = transb_bool ? n : k; - if (arch < 100 && transb_bool) { + ret.ldb = is_B_transposed ? n : k; + if (arch < 100 && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = transB; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; - ret.ldb = k; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = is_B_transposed ? n : k; } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = CUBLAS_OP_N; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; // Requirements from @@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; From 962d9c53423e604f3c22c3ad634bc5a0d66e4f7c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 9 Apr 2025 16:25:23 -0400 Subject: [PATCH 11/53] [JAX] Scaling Enum Abstracting (#1655) * scaling enum abstract * rm NVTE_ from ScalingMode names * rework scaling mode enum in grouped gemm * fix norm sharding --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 4 +- examples/jax/encoder/test_multigpu_encoder.py | 4 +- .../jax/encoder/test_single_gpu_encoder.py | 4 +- examples/jax/mnist/test_single_gpu_mnist.py | 4 +- qa/L0_jax_distributed_unittest/test.sh | 2 +- tests/jax/test_custom_call_compute.py | 40 ++--- tests/jax/test_distributed_layernorm.py | 2 +- tests/jax/test_distributed_layernorm_mlp.py | 2 +- tests/jax/test_layer.py | 4 +- .../jax/cpp_extensions/activation.py | 40 ++--- transformer_engine/jax/cpp_extensions/gemm.py | 34 ++-- transformer_engine/jax/cpp_extensions/misc.py | 2 +- .../jax/cpp_extensions/normalization.py | 162 ++++++++---------- .../jax/cpp_extensions/quantization.py | 24 +-- transformer_engine/jax/csrc/extensions.h | 22 ++- .../jax/csrc/extensions/activation.cpp | 65 +++---- .../jax/csrc/extensions/gemm.cpp | 37 ++-- transformer_engine/jax/csrc/extensions/misc.h | 23 +++ .../jax/csrc/extensions/normalization.cpp | 21 ++- .../jax/csrc/extensions/pybind.cpp | 8 +- .../jax/csrc/extensions/quantization.cpp | 58 +++++-- transformer_engine/jax/flax/module.py | 2 +- .../jax/quantize/dequantizer.py | 4 +- transformer_engine/jax/quantize/helper.py | 16 +- transformer_engine/jax/quantize/quantizer.py | 15 +- .../jax/quantize/scaling_modes.py | 25 ++- transformer_engine/jax/quantize/tensor.py | 3 +- 27 files changed, 335 insertions(+), 292 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7e6605c9fe..eabd1b2a3f 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -448,8 +448,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ba62d964fa..839bc3175e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -416,8 +416,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1300be01bb..df78157cc5 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -327,8 +327,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 4022cb7493..435750a1db 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -306,8 +306,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3253861484..3fbfb9cf5c 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4dc07a2eea..8917e92465 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -48,21 +48,21 @@ LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) + supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) def is_shape_supported_by_mxfp8(input_shape): try: if isinstance(input_shape, type(pytest.param(0))): input_shape = input_shape.values[0] - ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) + ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True except: # get_scale_shapes will raise an exception if the shape is not supported @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=q_layout, ) @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8( self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout, ) @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_layout=q_layout, ) @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -533,7 +533,7 @@ class TestFusedQuantize: def test_quantize_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis ): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( + if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization( in_dtype=in_dtype, input_shape=input_shape, out_dtype=in_dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=QuantizeLayout.ROWWISE, @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout): scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 6d4cde364f..476d455a6a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -29,7 +29,7 @@ } is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 4350d5e8f3..cf311ac404 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -36,7 +36,7 @@ is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b89530c19f..a21583a98c 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -39,7 +39,7 @@ def enable_fused_attn(): is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) QUANTIZE_RECIPES = [] """ Find supported scaling modes""" @@ -313,7 +313,7 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: _, updated_quantize_meta = flax.core.pop( updated_state[0], QuantizeConfig.COLLECTION_NAME ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7676781c3..c27f6f50f7 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -162,7 +162,7 @@ def lowering( assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ) return out @@ -282,7 +282,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -293,9 +293,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -339,7 +339,7 @@ def partition( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -350,9 +350,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -391,7 +391,7 @@ def sharded_impl(x, scale): ) ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -463,7 +463,7 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -545,7 +545,7 @@ def lowering( dz, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), @@ -673,7 +673,7 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -691,9 +691,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -743,7 +743,7 @@ def partition( ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -761,9 +761,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -810,7 +810,7 @@ def sharded_impl(dz, x, scale): else: global_dbias = local_dbias - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -928,7 +928,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1042,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused @@ -1095,7 +1095,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1df2bcc97f..0327542c2f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -98,7 +98,7 @@ def lowering( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + scaling_mode=scaling_mode.value, ) @staticmethod @@ -123,7 +123,7 @@ def impl( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, out_flat_size=out_flat_size, ) @@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8( ): """FP8 GEMM for XLA pattern match""" assert ( - rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d( JAX GEMM for MXFP8 via scaled_matmul """ assert ( - rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING + rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper @@ -291,10 +291,10 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) - if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") @@ -403,7 +403,7 @@ def grouped_gemm( rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -415,7 +415,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -427,13 +427,13 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) @@ -470,13 +470,13 @@ def grouped_gemm( dims.append((bm, bn, k)) lhs_contig_.append(lhs_3d.reshape(-1)) rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) if bias_list is not None: @@ -493,8 +493,8 @@ def grouped_gemm( # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + if scaling_mode == ScalingMode.NO_SCALING: + scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING # Perform batched GEMM on flattened inputs out_contig = GroupedGemmPrimitive.outer_primitive.bind( @@ -505,7 +505,7 @@ def grouped_gemm( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, out_flat_size=out_flat_size, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index c79eda5568..d64104ac27 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 74882c92db..388d4f17ee 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -105,6 +105,26 @@ def abstract( if norm_type == NVTE_Norm_Type.LayerNorm: assert gamma_aval.size == beta_aval.size + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + if norm_type == NVTE_Norm_Type.RMSNorm: + mu_aval = mu_aval.update(shape=(1,)) + + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -112,33 +132,13 @@ def abstract( jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(out_dtype), norm_type, - scaling_mode.value, + scaling_mode, zero_centered_gamma, epsilon, get_forward_sm_margin(), is_2x, ) - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_aval = mu_aval.update(shape=(1,)) - - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x_aval.shape, is_padded=not is_outer - ) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) - colwise_out_aval = jax.core.ShapedArray( - shape=x_aval.shape if is_2x else (1,), dtype=out_dtype - ) - - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - - wkspace_aval = x_aval.update( + wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) @@ -274,9 +274,9 @@ def impl( scale_shapes=scale_shapes, is_outer=False, ) - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x.shape, is_padded=False - ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) # slice out padding for mxfp8, noop for DelayedScaling scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape @@ -364,6 +364,8 @@ def infer_sharding_from_operands( del zero_centered_gamma, epsilon, out_dtype, result_infos del scale_dtype, scale_shapes, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) + out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -371,34 +373,27 @@ def infer_sharding_from_operands( "and hurt performance." ) - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") + + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") output = ( out_sharding, colwise_out_sharding, @@ -427,8 +422,11 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) g_spec = get_padded_spec(arg_infos[2]) b_spec = get_padded_spec(arg_infos[3]) + out_spec = (*x_spec[:-1], None) + if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -445,43 +443,30 @@ def partition( f"{NormFwdPrimitive.name} does not support sharding of parameter beta " "Enforcing no sharding of parameters hidden dim! " ) - x_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x" - ) - g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma") - b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta") - out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out") - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" + ) rsigma_sharding = NamedSharding( - mesh, - PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]), - desc="NormFwdPrimitive.rsigma", + mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale" - ) - scale_inv_sharding = scale_sharding.duplicate_with_new_description( - "NormFwdPrimitive.scale_inv" + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -517,7 +502,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -824,7 +809,6 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) - if quantizer is None: output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, @@ -835,7 +819,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -845,7 +829,7 @@ def layernorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -864,7 +848,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -873,7 +857,7 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -882,7 +866,7 @@ def layernorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) @@ -1017,7 +1001,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1027,7 +1011,7 @@ def rmsnorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -1046,7 +1030,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -1055,7 +1039,7 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1064,7 +1048,7 @@ def rmsnorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 034e149c50..2911b5a420 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -93,7 +93,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -114,6 +114,10 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + scaling_mode, + QuantizeLayout( + q_layout + ), # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -176,7 +180,7 @@ def lowering( ctx, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, @@ -302,7 +306,7 @@ def infer_sharding_from_operands( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -322,9 +326,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -374,7 +378,7 @@ def partition( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -394,9 +398,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -445,7 +449,7 @@ def sharded_impl(x, scale): is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -588,7 +592,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -31,6 +31,9 @@ #include "transformer_engine/activation.h" #include "utils.h" +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax { @@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); + +pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, bool is_2x); + // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); @@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training); @@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); - -pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x); + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e71597e4b3..fc7f231f34 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -17,7 +17,7 @@ namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, + Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis @@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("act_enum") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto dact_input_tensor = TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); - auto output_tensor = TensorWrapper(static_cast(scaling_mode)); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter @@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } } - if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, - bool is_dbias, int64_t act_enum) { + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); - auto scaling_mode = static_cast(scaling_mode_enum); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK( - !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { @@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") + .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias") - .Attr("act_enum"), + .Attr("is_dbias"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e5ec160c91..d4b9bf720e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const int64_t &scaling_mode, + int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, cudaStream_t stream) { size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); @@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, - nullptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, - nullptr, reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); + rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); + + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); size_t sinv_k = k / MXFP8_BLOCK_SIZE; @@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh rhs_sinv_shape[1] = sinv_k; // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, lhs_sinv_shape); rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); lhs_ptr += m * k * lhs_dtype_bytes; @@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { + Result_Type workspace_flatten, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode) { // Inputs auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); @@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Ret() // out_flatten .Ret() // workspace_flatten .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8526e20c0..f7577c24f3 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -40,5 +40,28 @@ enum class QuantizeLayout { ROWWISE_COLWISE, }; +enum class JAXX_Scaling_Mode : int64_t { + NO_SCALING = 0, + DELAYED_TENSOR_SCALING = 1, + MXFP8_1D_SCALING = 2, +}; + +static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { + switch (mode) { + case JAXX_Scaling_Mode::NO_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::MXFP8_1D_SCALING: + return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + break; + default: + NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 03855753cf..e23e42f528 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -14,7 +14,8 @@ namespace jax { pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - auto _scaling_mode = static_cast(scaling_mode); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated - if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { int temp = 1; output_tensor.set_columnwise_data(static_cast(&temp), out_dtype, input_shape); } @@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, @@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, int scaling_mode, bool is_2x) { + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); - auto _scaling_mode = static_cast(scaling_mode); auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); if (is_fp8_dtype(out_dtype)) { @@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc scale_inv_buf->dimensions().back()}); } - if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); @@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, @@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("zero_centered_gamma") .Attr("epsilon") .Attr("sm_margin") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ebdfe461c7..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("RMSNorm", NVTE_Norm_Type::RMSNorm) .export_values(); - pybind11::enum_(m, "NVTE_Scaling_Mode", pybind11::module_local()) - .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) - .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) - .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) + .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index b48ee8a9b9..481dbd7cdf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -13,7 +13,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ int temp = 0; auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); - auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); + + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + TensorWrapper dummy_workspace; nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), @@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type output_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, - int64_t flatten_axis) { + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -54,7 +83,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto scaling_mode = static_cast(scaling_mode_enum); auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); @@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T std::vector workspace_shape{workspace_dims.begin(), workspace_dims.end()}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis"), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a944848881..45ff8d7ed9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -361,7 +361,7 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b1e9ba03b4..d68eb3c6c2 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -84,8 +84,8 @@ def _dq_func_block_scaling(scaled_tensor): ) funcs = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, + ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } @staticmethod diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d144aa69d..98f280b9a9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) return (False, "Unsupported scaling_mode!") def is_fp8_available( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, ) -> Tuple[bool, str]: """Check if FP8 is available for the given scaling mode and GPU. @@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ValueError: If the recipe type is not supported """ if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.NVTE_DELAYED_TENSOR_SCALING + return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.NVTE_MXFP8_1D_SCALING + return ScalingMode.MXFP8_1D_SCALING raise ValueError("Invalid fp8_recipe!") @@ -217,7 +217,7 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False IF_QUANTIZE_2X: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING + SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 @@ -253,11 +253,11 @@ def finalize(cls) -> None: cls.MARGIN = 0.0 cls.FP8_FORMAT = recipe.Format.HYBRID cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.IF_QUANTIZE_2X = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index bd7045453b..b57043a034 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -530,8 +530,8 @@ class QuantizerFactory: """ quantizer_type_map = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } @staticmethod @@ -556,8 +556,9 @@ def create( A single quantizer or tuple of quantizers """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted - # assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING - if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): + assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" + # import pdb; pdb.set_trace() + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] @@ -651,4 +652,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 95bbc9bb41..34f63a994c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -19,6 +19,8 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp +from transformer_engine_jax import JAXX_Scaling_Mode + __all__ = ["ScalingMode"] @@ -216,25 +218,20 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) -# (Phuong: Map the NVTEScalingMode value to the ScalingMode - - @dataclass(frozen=True) @register_pytree_node_class class ScalingMode(Enum): """Enumeration of tensor scaling modes with their corresponding metadata implementations. This class defines the available scaling modes for tensor quantization: - - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_INVALID_SCALING: Invalid scaling mode - - NVTE_NO_SCALING: No scaling applied + - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales + - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - NO_SCALING: No scaling applied """ - NVTE_DELAYED_TENSOR_SCALING = 0 - NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 100 - NVTE_NO_SCALING = 1000 + NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -329,8 +326,8 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR - ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c34a235d94..0ef30f4728 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -236,13 +236,12 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # TODO(Phuong): Handle padding !? scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv - # TODO(Phuong): constaint padded scale_inv? return ScaledTensor1x( data=data, scale_inv=scale_inv, From 20e95ba3d3f7af540678af74dc5960332221a0b3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:05:13 -0700 Subject: [PATCH 12/53] [PyTorch] Explicitly specify quantized tensor usages needed for linear op backward (#1646) Explicitly specify quantized tensor usages needed for linear op backward Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 9 ++++-- .../pytorch/ops/basic/basic_linear.py | 32 +++++++++++-------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9c1a842cd8..ddc79af426 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1420,15 +1420,17 @@ def test_activation( test_device=device, test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_is_fp8=quantized_compute, requires_grad=False, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1459,6 +1461,7 @@ def test_activation( swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantized_compute), make_op(), te_ops.Quantize(forward=quantized_compute, backward=False), ) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb93eb5e6b..b451acea9a 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -523,7 +523,7 @@ def _functional_forward( # Configure input tensor for backward pass if own_quantized_x_local: - x_local.update_usage(rowwise_usage=False) + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save @@ -679,7 +679,9 @@ def _functional_backward( quantizer=input_quantizer, ) else: - if not isinstance(x_local, QuantizedTensor): + if isinstance(x_local, QuantizedTensor): + x_local.update_usage(columnwise_usage=True) + else: x_local = input_quantizer(x_local) x = x_local else: @@ -706,15 +708,19 @@ def _functional_backward( raise ValueError("Weight tensor is required to compute input grad") w = weight w_is_quantized = isinstance(w, QuantizedTensor) - if with_quantized_compute and not w_is_quantized: - if weight_quantizer is None: - raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(columnwise=True) - w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) + if with_quantized_compute: + if w_is_quantized: + w.update_usage(columnwise_usage=True) + else: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + else: + if w_is_quantized: + w = w.dequantize(dtype=dtype) + elif w.dtype != dtype: + w = w.to(dtype=dtype) # Synchronize tensor-parallel communication _wait_async(dy_async) @@ -867,8 +873,8 @@ def op_forward( # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. - input_quantizer.set_usage(columnwise=weight_requires_grad) - weight_quantizer.set_usage(columnwise=False) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) # Get autocast dtype if needed dtype = None From 0da604497446c5ef4cf28e35c7362f5ad913e434 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:05:24 -0700 Subject: [PATCH 13/53] [PyTorch] Debug checkpointing with te.Sequential (#1629) * Debug checkpointing with te.Sequential Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 116 ++++++++++++++++++ transformer_engine/pytorch/ops/op.py | 71 ++++++----- .../pytorch/tensor/mxfp8_tensor.py | 5 +- 3 files changed, 158 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ddc79af426..59af228861 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Iterable +import io import math from typing import Optional @@ -1885,3 +1886,118 @@ def test_backward_linear_add( torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +class TestCheckpointing: + """Tests for checkpointing""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear( + self, + *, + pre_checkpoint_steps: int = 2, + post_checkpoint_steps: int = 2, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Check checkpointing with linear op""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + + # Construct model + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_save = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25) + + # Warmup training steps + for _ in range(pre_checkpoint_steps): + x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn(out_shape, dtype=dtype, device=device) + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(x) + y.backward(dy) + optim_save.step() + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save( + {"model": model_save.state_dict(), "optim": optim_save.state_dict()}, + byte_stream, + ) + checkpoint_bytes = byte_stream.getvalue() + del byte_stream + + # Synthetic data for evaluation + xs_save = [ + torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + for _ in range(post_checkpoint_steps) + ] + with torch.no_grad(): + xs_load = [x.clone().requires_grad_() for x in xs_save] + dys = [ + torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps) + ] + + # Training steps with original model + ys_save = [] + for i in range(post_checkpoint_steps): + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(xs_save[i]) + y.backward(dys[i]) + optim_save.step() + ys_save.append(y) + + # Load checkpoint + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_load = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) + state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) + model_load.load_state_dict(state_dict["model"]) + optim_load.load_state_dict(state_dict["optim"]) + + # Training steps with loaded model + ys_load = [] + for i in range(post_checkpoint_steps): + optim_load.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_load(xs_load[i]) + y.backward(dys[i]) + optim_load.step() + ys_load.append(y) + + # Check that original and loaded model match exactly + tols = {"rtol": 0, "atol": 0} + for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): + torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close(param_load.grad, param_save.grad, **tols) + for y_load, y_save in zip(ys_load, ys_save): + torch.testing.assert_close(y_load, y_save, **tols) + for x_load, x_save in zip(xs_load, xs_save): + torch.testing.assert_close(x_load.grad, x_save.grad, **tols) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..ad32055479 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -19,6 +19,7 @@ DelayedScalingRecipeState, FP8GlobalStateManager, RecipeState, + fp8_autocast, ) from ..tensor import Quantizer @@ -508,7 +509,7 @@ def forward( def get_extra_state(self) -> torch.Tensor: """Serialize extra state - Contains metadata for FP8 casting. + Contains metadata for quantization recipe. """ @@ -540,23 +541,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: dst.copy_(src, non_blocking=True) return dst - # Store FP8 state + # Store quantizer state if needed state = {} for mode in ("forward", "backward"): - # Get state for a given FP8 tensor - if self.num_quantizers(mode) == 0: + # Skip if op has no quantizer state + if self._fp8_metas is None or self._fp8_metas[mode] is None: continue - fp8_meta = self.get_fp8_meta(mode) + + # Quantizer state + fp8_meta = self._fp8_metas[mode] state[mode] = {} + state[mode]["recipe"] = fp8_meta["recipe"] - # Store tensors - if "scaling_fwd" in fp8_meta: - state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) - if "scaling_bwd" in fp8_meta: - state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + # Copy tensors to CPU and store + if state[mode]["recipe"].delayed(): + if mode == "forward": + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items extra = {} @@ -595,37 +600,37 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) dst.copy_(src, non_blocking=True) - # Load FP8 state + # Load quantizer state if needed for mode in ("forward", "backward"): - # Get state for a given FP8 tensor + # Skip if checkpoint has no quantizer state if mode not in state: continue - if self.num_quantizers(mode) == 0: - continue - fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue - # Load extra state + # Get op's quantizer state, initializing if needed + if self._fp8_metas is None or self._fp8_metas[mode] is None: + with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + self._reset_quantization_recipe_state() + fp8_meta = self._fp8_metas[mode] + + # Load extra items + fp8_meta["recipe"] = state[mode]["recipe"] fp8_meta.update(state[mode]["extra_fp8_variables"]) - if "amax_history_fwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) - elif "amax_history_bwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Load tensors - fp8_meta = self.get_fp8_meta(mode) - if "scaling_fwd" in fp8_meta: - fp8_meta_fwd = fp8_meta["scaling_fwd"] - copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) - if "scaling_bwd" in fp8_meta: - fp8_meta_bwd = fp8_meta["scaling_bwd"] - copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + if state[mode]["recipe"].delayed(): + if mode == "forward": + copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) + copy_tensor( + state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history + ) + if mode == "backward": + copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) + copy_tensor( + state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history + ) # Finish CPU-GPU memory transfers torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 843c7936f2..2694319a0f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -347,6 +347,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + shape: torch.shape, ) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -361,10 +362,11 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling""" return ( MXFP8Tensor._make_in_reduce_ex, ( @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self.shape, ), ) From a8f0fe03b4ac801ee50b35ff24ec2998eaa301ac Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Thu, 10 Apr 2025 10:36:04 -0700 Subject: [PATCH 14/53] Blockwise scaling linear quantization recipe (#1559) * Add GEMM logic for blockwise quantized tensors. GEMM test cases included in pytorch integration. Signed-off-by: Keith Wyss * Update NVTE_BLOCK_SCALING for GEMM. Signed-off-by: Keith Wyss * Gate feature on CUDA 12.9 Signed-off-by: Keith Wyss * Gemm typo. Signed-off-by: Keith Wyss * Remove unecessary type converter change. Signed-off-by: Keith Wyss * Reflect epilogue availability and test supported epilogues. Signed-off-by: Keith Wyss * GEMM simplifications from recipe branch. Signed-off-by: Keith Wyss * Format py code. Signed-off-by: Keith Wyss * Update GEMM DGelu tests to match support depending on output dtype. Signed-off-by: Keith Wyss * Force pow2Scales in GEMM Signed-off-by: Keith Wyss * Add GEMM test to pytorch test suite. Signed-off-by: Keith Wyss * Add copyright to GEMM test. Signed-off-by: Keith Wyss * Update import for GEMM test. Signed-off-by: Keith Wyss * Add license. Signed-off-by: Keith Wyss * Update test gemm supported predicate. Signed-off-by: Keith Wyss * Use sgemm like interfaces and naming. Signed-off-by: Keith Wyss * Rewrite GEMM comment. Signed-off-by: Keith Wyss * MR Feedback. Signed-off-by: Keith Wyss * Recipe setup for Linear modules. Signed-off-by: Keith Wyss * Use 12.9 feature test. Signed-off-by: Keith Wyss * Run against tensor dumps from internal library. Signed-off-by: Keith Wyss * Update FIXME to TODO with linked issue. Signed-off-by: Keith Wyss * Update full recompute feature to save recipe. The recompute context uses the same recipe and fp8 settings as the original fwd pass. Signed-off-by: Keith Wyss * MR Feedback. Avoid reusing quantizer objects. Signed-off-by: Keith Wyss * Update logic in module. Signed-off-by: Keith Wyss * Format py. Signed-off-by: Keith Wyss * Update for PP bug. Signed-off-by: Keith Wyss * Update test numerics. Signed-off-by: Keith Wyss * Update force_power_of_2 scales in the recipe. Signed-off-by: Keith Wyss * Update usage method to satisfy upstream changes. Signed-off-by: Keith Wyss * fix subchannel recipe in distributed test with bf16 gather Signed-off-by: zhongboz * Edit and cleanup BF16 gather code. Signed-off-by: Keith Wyss * Update test import. Signed-off-by: Keith Wyss * support columnwise only mode to 1D quantize kernel Signed-off-by: zhongboz * Format and move enum Signed-off-by: Keith Wyss * Skip alloc. Signed-off-by: Keith Wyss * try async bf16 gather Signed-off-by: zhongboz * Format python code. Signed-off-by: Keith Wyss * Document and type code. Signed-off-by: Keith Wyss * Update pytorch lint errors. Signed-off-by: Keith Wyss * Dont set high precision dtype. Signed-off-by: Keith Wyss * Add test for sanity and CG; fix CG for sequential? Signed-off-by: Kirthi Shankar Sivamani * Keep make_quantizers API stable Update num_quantizers instead to pass cuda_graph tests. Signed-off-by: Keith Wyss * Fix import name. Signed-off-by: Keith Wyss * Rename recipe method. Signed-off-by: Keith Wyss * Skip grouped linear sanity test. Signed-off-by: Keith Wyss * Set usage before BF16 gather. Signed-off-by: Keith Wyss * refactor for nvte_quantize_v2 Signed-off-by: zhongboz * Format code. Signed-off-by: Keith Wyss * Cleanup nvte_quantize_v2 Signed-off-by: Keith Wyss * Test fp32 scales. Signed-off-by: Keith Wyss * Disable CUDA graph. Signed-off-by: Keith Wyss * Simplify layernorm linear Signed-off-by: Keith Wyss * Cleanup layernorm linear. Signed-off-by: Keith Wyss * LayerNorm linear bwd gather logic. Signed-off-by: Keith Wyss * Communication updates. Signed-off-by: Keith Wyss * Update transformer_engine/pytorch/ops/op.py Apply MR comment change. Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: kwyss-nvidia * Lint fix. Signed-off-by: Keith Wyss * MR feedback. Signed-off-by: Keith Wyss * Enable cuda graph tests. Signed-off-by: Keith Wyss * Reduce chance of spurious failure and reword. Signed-off-by: Keith Wyss * Review suggestions from @timmoon10 Signed-off-by: Tim Moon * Update CPP tests. Signed-off-by: Keith Wyss * Update common.h Signed-off-by: Xin Yao * Update test_float8blockwisetensor.py Signed-off-by: Xin Yao --------- Signed-off-by: Keith Wyss Signed-off-by: zhongboz Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: kwyss-nvidia Signed-off-by: Tim Moon Signed-off-by: Xin Yao Co-authored-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: Xin Yao --- .../cpp/operator/test_cast_float8blockwise.cu | 32 ++- tests/cpp/test_common.cu | 7 +- tests/cpp/test_common.h | 14 +- tests/pytorch/distributed/run_numerics.py | 5 +- tests/pytorch/distributed/test_numerics.py | 7 +- tests/pytorch/test_cuda_graphs.py | 8 + .../test_float8_blockwise_gemm_exact.py | 9 +- .../test_float8_blockwise_scaling_exact.py | 265 +++++++++++++++--- .../test_float8_current_scaling_exact.py | 22 +- tests/pytorch/test_float8blockwisetensor.py | 21 ++ tests/pytorch/test_float8tensor.py | 28 +- tests/pytorch/test_numerics.py | 14 + tests/pytorch/test_sanity.py | 34 +++ .../common/activation/activation_template.h | 8 +- transformer_engine/common/common.h | 6 +- .../common/include/transformer_engine/cast.h | 12 +- .../transformer_engine/transformer_engine.h | 12 + transformer_engine/common/recipe/__init__.py | 101 +++++++ .../common/transformer_engine.cpp | 6 + .../common/transpose/cast_transpose.h | 28 +- .../quantize_transpose_vector_blockwise.cu | 70 +++-- transformer_engine/common/util/cast.cu | 32 ++- .../common/util/cast_kernels.cuh | 36 ++- .../pytorch/cpp_extensions/gemm.py | 5 + transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/activation.cpp | 7 +- .../pytorch/csrc/extensions/cast.cpp | 15 +- .../pytorch/csrc/extensions/normalization.cpp | 20 +- .../pytorch/csrc/extensions/quantizer.cpp | 8 +- .../pytorch/csrc/extensions/transpose.cpp | 1 + transformer_engine/pytorch/distributed.py | 85 +++++- transformer_engine/pytorch/fp8.py | 138 ++++++++- transformer_engine/pytorch/module/base.py | 31 +- .../pytorch/module/grouped_linear.py | 2 + .../pytorch/module/layernorm_linear.py | 38 ++- .../pytorch/module/layernorm_mlp.py | 94 +++++-- transformer_engine/pytorch/module/linear.py | 54 +++- .../pytorch/ops/basic/basic_linear.py | 7 + transformer_engine/pytorch/ops/op.py | 15 +- .../_internal/float8_blockwise_tensor_base.py | 16 +- .../pytorch/tensor/float8_blockwise_tensor.py | 24 +- 41 files changed, 1121 insertions(+), 220 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cc27f72769..10b52e065f 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -19,6 +19,12 @@ using namespace test; namespace { +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + constexpr size_t kBlockLen = 128; enum ProcessingMethod { @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector(&input, fill_case); fillUniform(&grad); + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, - opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, fillUniform(&grad); Tensor workspace; + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(opts.force_pow_2_scales); + quant_config.set_amax_epsilon(opts.amax_epsilon); switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_quantize(input.data(), output_c.data(), 0); + nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr); break; } } @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, - {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {1, 16}, {65, 96}, {256, 256}, {993, 512}, + {256, 65536}, {4096, 1632}, {1024, 1}, + {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, }; std::vector input_scenarios = { @@ -429,6 +441,8 @@ std::vector Activation_types = { std::vector amax_epsilons = { 0.0f, + 1.0f, // Make large to be observable. + }; } // namespace @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 071c2186e0..61d3075265 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -216,8 +216,7 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode, - const QuantizationOptions* q_opts) { + const NVTEScalingMode &scaling_mode) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } - if (q_opts != nullptr) { - NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); - NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation."); - } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 08df3cf7d1..d5ecc6d0f5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,29 +95,21 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; -struct QuantizationOptions { - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - size_t block_scaling_dim = 2u; -}; - class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, - const QuantizationOptions* q_opts = nullptr) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ae5993eb1e..b423bce53d 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -19,6 +19,7 @@ MXFP8BlockScaling, DelayedScaling, Float8CurrentScaling, + Float8BlockScaling, Format, Recipe, ) @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe: return MXFP8BlockScaling() if QUANTIZATION == "fp8_cs": return Float8CurrentScaling() + if QUANTIZATION == "fp8_block_scaling": + return Float8BlockScaling() return te.fp8.get_default_fp8_recipe() @@ -85,7 +88,7 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"): global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE SEQ_LEN = 32 BATCH_SIZE = 32 diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index b4e2b680b3..632f50e90a 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -28,6 +28,9 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -48,7 +51,7 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -56,4 +59,6 @@ def test_distributed(quantization): pytest.skip(fp8_available) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) _run_test(quantization) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 5a1dc3f732..7bfe506f26 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -27,6 +27,9 @@ # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -55,6 +58,7 @@ class ModelConfig: recipe.DelayedScaling(), recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] # Supported data types @@ -316,9 +320,13 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and module == "linear_op": + pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") # Run model with different CUDA graph settings. model_config = model_configs[model_config] kwargs = dict( diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9a1cfa2db8..ec23cfe8c5 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,21 +8,18 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ) + supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + return supported def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e638fe8c5b..0baee4975d 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -4,11 +4,14 @@ from typing import Tuple import math +import os +import pathlib import pytest import torch import transformer_engine as te import transformer_engine_torch as tex -from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -18,10 +21,29 @@ BlockwiseQuantizerReference, QuantizeResult, ) +from test_float8_current_scaling_exact import ( + TestFP8RecipeLinearBase, + TestFP8RecipeLayerNormLinearBase, +) + +# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() + + +class GetRecipes: -# TODO replace with call to fp8.py when recipe added. -recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 -reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + @staticmethod + def none(): + return None + + @staticmethod + def fp8_blockwise(): + # return default configs + return Float8BlockScaling() def initialize_for_many_scales( @@ -66,35 +88,7 @@ def initialize_for_many_scales( return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles. - (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] -) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) -def test_quantization_block_tiling_versus_reference( +def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, @@ -199,12 +193,90 @@ def test_quantization_block_tiling_versus_reference( [ # full tile cases (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (256, 256), + (2048, 1024), + # Padding required cases + (256, 272), + (303, 300), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference_fp32_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + check_quantization_block_tiling_versus_reference( + x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( @@ -292,3 +364,130 @@ def test_quantization_block_tiling_extrema_versus_reference( atol=0.0, rtol=0.0, ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_blockwise), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1.6, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 9741b1258c..8911ecc159 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -82,7 +82,8 @@ def _get_sum_abs_error(a, b): @staticmethod def _get_mean_abs_relative_error(a, b): - return torch.mean(torch.abs((a - b) / b)) + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) @staticmethod def _load_golden_tensor_values(a, b): @@ -97,9 +98,12 @@ def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { @@ -437,9 +441,13 @@ def _check_golden_tensor_dumps( fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) # Expected tensor names based on the naming template - scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example - "ScalingType.PER_TENSOR" - ) + if recipe.float8_current_scaling(): + scaling_type = "ScalingType.PER_TENSOR" + elif recipe.float8_block_scaling(): + scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W" + else: + scaling_type = "Unknown" + current_seed = torch.initial_seed() # Get the current seed expected_tensor_names = { diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index d030426b74..6d3e879970 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -110,7 +110,10 @@ def _test_quantize_dequantize( dims = _to_list(dims) # Initialize random data + # Note: Make sure values are not all close to zero, or else + # test may pass trivially. x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref.view(-1)[0] = 0.75 x_ref_cuda = x_ref.to("cuda") # Cast to FP8 and back @@ -150,6 +153,24 @@ def test_quantize_dequantize_dtypes( ) self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1]) + def test_quantize_dequantize_columnwise_only( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=False, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True + ) + @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 42600e3099..d36da704b0 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import io -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytest import torch @@ -158,6 +158,32 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("noop", [True, False]) + def test_quantize_dequantize_noop( + self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool + ) -> None: + noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") + if noop: + noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda") + dims = 23 + scale: float = 3.5 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + # if noop, then when we input a different tensor, output should still be x_fp8_orig + x_ref_noop_test = 2 * x_ref.cuda() + x_fp8_orig = x_fp8.clone() + x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor) + if noop_tensor.item() == 1.0: + torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0) + else: + torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype]) + def test_basic_ops( self, dims: DimsType = 23, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35f65a75f4..7a930b6cde 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -50,6 +50,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) sm_80plus = get_device_compute_capability() >= (8, 0) @@ -104,6 +107,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), ] @@ -563,6 +567,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -675,6 +681,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1528,6 +1536,8 @@ def test_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Grouped linear for FP8 blockwise unsupported.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1723,6 +1733,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip("MXFP8 unsupported for grouped linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for grouped linear.") + if recipe.float8_block_scaling(): + pytest.skip("Float8 block scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1933,6 +1945,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 69ac8f7996..d8552d63a4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,6 +46,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -106,6 +109,7 @@ def is_fp8_supported(self): None, # Test non-FP8 recipe.MXFP8BlockScaling(), # Test default recipe.Float8CurrentScaling(), # Test default + recipe.Float8BlockScaling(), # Test default recipe.DelayedScaling(), # Test default recipe.DelayedScaling( # Test most_recent algo amax_history_len=16, @@ -439,6 +443,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -470,6 +476,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -502,6 +510,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -543,10 +553,14 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8(): pytest.skip("Grouped linear does not support MXFP8") if fp8_recipe.float8_current_scaling(): pytest.skip("Grouped linear does not support FP8 current scaling") + if fp8_recipe.float8_block_scaling(): + pytest.skip("Grouped linear does not support FP8 block scaling") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -590,6 +604,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -640,6 +656,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -707,6 +725,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -766,6 +786,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -823,6 +845,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -858,6 +882,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -896,6 +922,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -937,6 +965,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): @@ -979,8 +1009,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling(): + pytest.skip("cuda graph not supported for float8_block_scaling recipe") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 708403f911..67f173a4ab 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template @@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, - workspace, stream); + quantize_helper(input, grad, output, dbias, workspace, + nullptr, stream); } template diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b1fe436379..728f8ad147 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -233,10 +233,12 @@ struct Tensor { struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; + NVTETensor noop_tensor = nullptr; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float) // amax_epsilon + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor) // noop_tensor }; }; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 7fa7957fa4..64136b2c43 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,7 +89,7 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. @@ -102,6 +102,16 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream); +/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output quantized tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ba47b9d38c..d3ee446f83 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -286,6 +286,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigForcePow2Scales = 0, /*! Small value to add to amax for numerical stability */ kNVTEQuantizationConfigAmaxEpsilon = 1, + /*! Noop tensor (containing a scalar). + If the scalar element value = 1, quantization kernel will early exit. + This is a tensor because the flag must be on GPU in order to enable + conditional early even when captured in a static CUDA graph. + */ + kNVTEQuantizationConfigNoopTensor = 2, kNVTEQuantizationConfigNumAttributes }; @@ -724,6 +730,12 @@ class QuantizationConfigWrapper { &amax_epsilon, sizeof(float)); } + /*! \brief Set noop tensor pointer */ + void set_noop_tensor(NVTETensor noop_tensor) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor, + sizeof(NVTETensor)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b676bf6ab0..80857e565c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -5,6 +5,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations import warnings +import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -81,6 +82,10 @@ def float8_per_tensor_scaling(self): """Whether the given recipe is per-tensor scaling.""" return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + def float8_block_scaling(self): + """Whether the given recipe is float8 blockwise scaling.""" + return isinstance(self, Float8BlockScaling) + @dataclass() class DelayedScaling(Recipe): @@ -287,3 +292,99 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + + +@dataclass() +class Float8BlockScaling(Recipe): + """ + Use block-wise scaling for FP8 tensors. + + In this strategy, tensors are scaled in blockwise fashion. Values within + each block share a common scaling factor. The block dimensionality + can be configured. The scaling factors are float32 containers. They + will by default be constrained to powers of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), the quantized tensor and its transpose are not numerically + equivalent. Due to this, when Transformer Engine needs both the FP8 tensor + and its transpose (e.g. to calculate both forward and backward pass), + during the quantization both versions are computed from the high precision + input to avoid double quantization errors. + + NOTE: To relax the default constraint that scales be powers of 2, set env variable + NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. + export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + Or initialize the Recipe with non-default QParams in code for increased control. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} + used for quantization of gradient tensor dY + x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for x. + w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for w. + grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) + qblock scaling for grad. + fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + """ + + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" + + fp8_format: Format = Format.E4M3 + fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0) + x_block_scaling_dim: int = 1 + w_block_scaling_dim: int = 2 + grad_block_scaling_dim: int = 1 + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" + assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" + assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" + assert not ( + self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert not ( + self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2 + ), "2D by 2D block gemm not supported." + assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." + assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." + assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"x_block_scaling_dim={self.x_block_scaling_dim}, " + f"w_block_scaling_dim={self.w_block_scaling_dim}, " + f"grad_block_scaling_dim={self.grad_block_scaling_dim}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 97df5892b6..706e0bd0b5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -429,6 +429,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(buf, &config_.amax_epsilon, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(buf, &config_.noop_tensor, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -458,6 +461,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigAmaxEpsilon: std::memcpy(&config_.amax_epsilon, buf, attr_size); break; + case kNVTEQuantizationConfigNoopTensor: + std::memcpy(&config_.noop_tensor, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 298d087337..3148b4f720 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -29,11 +29,35 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream); +// enum class for rowwise usage +enum class FP8BlockwiseRowwiseOption { + // No rowwise data + NONE, + // Rowwise data, scales in GEMM format + ROWWISE + // TODO: FP8 all gather requires some changes. + // 1. Compact scales are better for gathering than the GEMM format. +}; + +// enum class for columnwise usage +// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling +enum class FP8BlockwiseColumnwiseOption { + // No columnwise data + NONE, + // Columnwise data transposed from original shape. + // Scales in GEMM format corresponding to GEMM ingesting transposed column data. + COLUMNWISE_TRANSPOSE + // TODO: FP8 all gather requires some changes. + // 1. The transpose gets in the way of the all gather. + // 2. Compact scales are better for gathering than the GEMM format. +}; + void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scale, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 732d97999c..91f73dea1e 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -16,11 +16,15 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" #include "common/utils.cuh" namespace transformer_engine { namespace { +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + // clang-format off /* @@ -138,15 +142,17 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); template -__global__ void __launch_bounds__(kThreadsPerBlock) - block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, - OType* const output_t, CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, const size_t row_length, - const size_t num_rows, const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon, - bool return_transpose, bool pow_2_scaling) { +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow_2_scaling) { + bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; + bool return_columnwise_transpose = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; + using SMemVec = Vec; using OVec = Vec; union IVec { @@ -203,7 +209,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) __syncthreads(); // Step 2: Cast and store to output_c - { + if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -294,7 +300,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if (return_transpose) { + if (return_columnwise_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -389,10 +395,15 @@ namespace transformer_engine::detail { void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, - cudaStream_t stream) { + FP8BlockwiseRowwiseOption rowwise_option, + FP8BlockwiseColumnwiseOption columnwise_option, + const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + // assert that rowwise_option and columnwise_option are not both NONE + NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || + columnwise_option != FP8BlockwiseColumnwiseOption::NONE, + "rowwise_option and columnwise_option cannot both be NONE"); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -408,21 +419,24 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. - - NVTE_CHECK(input.shape.size() == output.shape.size(), - "Input and output must have the same shape."); - size_t scale_stride_x = 0; size_t scale_stride_y = 0; - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; - size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; - if (return_transpose) { + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + "Unexpected rowwise enum value"); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + } + + if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, + "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); if (output_t.shape.size() > 0) { @@ -469,10 +483,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, - pow2_scale);) // kAligned - ) // OutputType - ) // InputType + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 22a50025df..1f146c7a33 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + detail::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -44,6 +44,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no NVTE_API_CALL(nvte_quantize_noop); using namespace transformer_engine; + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = false; constexpr bool IS_ACT = false; @@ -51,8 +63,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, noop, output, - dbias, workspace, stream); + detail::quantize_helper( + input, grad, output, dbias, workspace, quant_config, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ -66,7 +78,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +92,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +106,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +120,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +134,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +148,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - activation_input, input, nullptr, output, dbias, workspace, stream); + activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c6a8b0f23c..a599d88530 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1215,9 +1215,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { @@ -1232,6 +1232,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe auto output_tensor = reinterpret_cast(output); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); + + const QuantizationConfig *quant_config_cpp = + reinterpret_cast(quant_config); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); switch (output_tensor->scaling_mode) { @@ -1263,11 +1269,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - constexpr bool force_pow_2_scales = true; + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, + output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } @@ -1275,12 +1281,18 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - constexpr bool force_pow_2_scales = true; - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/0.0, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() + ? FP8BlockwiseRowwiseOption::ROWWISE + : FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = + output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE + : FP8BlockwiseColumnwiseOption::NONE; + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..79d6391e79 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = [ "general_gemm", @@ -112,6 +113,10 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + # There is not use_split_accumulator == False + # implementation for Float8BlockwiseQTensorBase GEMM + use_split_accumulator = True args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 338f1fcbb1..3b349e7f09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,13 +167,13 @@ class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. DType dtype; - - private: // Options about how to quantize the tensor // Quantization scales are rounded down to powers of 2. bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + + private: int block_scaling_dim = 2; public: diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 1ef6f5258d..bf037fe931 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + // sanity check, since activation fusion is not supported for blockwise quantization yet + // need to raise an error here instead of silently going into act_func with wrong numerics + NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); } else { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c3ccff154..84e50dea22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob if (te_output.numel() == 0) return out; + QuantizationConfigWrapper quant_config; + quant_config.set_noop_tensor(te_noop.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; + // this config is used for cs scaling factor computation + // because compute scale is cannot be fused with quantize kernel + // so in nvte_quantize_v2 with current scaling, the quant config is not used again quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, + at::cuda::getCurrentCUDAStream()); return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index cbdeee5b48..dae6ce42e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,6 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -166,15 +167,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -293,6 +297,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + QuantizationConfigWrapper quant_config; if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); @@ -309,15 +314,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w allreduce_opts.reduceOp = c10d::ReduceOp::MAX; process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); } - QuantizationConfigWrapper quant_config; quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + auto my_quantizer_bw = static_cast(my_quantizer.get()); + quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); } - nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, + at::cuda::getCurrentCUDAStream()); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ac6292e53..fbf31a7f5b 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -257,12 +257,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires pow2 scales"); - NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, - "Pending additional parameters to the nvte_quantize API, " - "float8 block quantization requires amax_epsilon==0"); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index a873586032..e12990f79c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -69,6 +69,7 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + // TODO: switch to nvte_quantize_v2 with advanced numerical options nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e245b788b4..7a1fde164b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -24,10 +24,11 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -937,6 +938,74 @@ def _all_gather_fp8( return out, handle +def _all_gather_fp8_blockwise( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, # pylint: disable=unused-argument + quantizer: Optional[Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """ + All-gather FP8 tensor along first dimension for blockwise quantization. + + Returns: quantizer(gather(inp)) + + NOTE: The implementation is not sophisticated enough to honor async_op=True. + In some cases it falls back to synchronous gather and invokes the quantizer. + """ + + # Input tensor attributes + device: torch.device + dtype: torch.dtype + if isinstance(inp, torch.Tensor): + device = inp.device + dtype = inp.dtype + elif isinstance(inp, Float8BlockwiseQTensorBase): + if inp._rowwise_data is not None: + device = inp._rowwise_data.device + elif inp._columnwise_data is not None: + device = inp._columnwise_data.device + else: + raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " + f"found {inp.__class__.__name__})" + ) + world_size = get_distributed_world_size(process_group) + + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): + raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") + + # Output tensor dims + if out_shape is None: + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # Doing BF16 gather for now as baseline because it's simpler + if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + out = quantizer(out) + return out, None + # Implementation of fp8 gather needs to account for: + # * Getting columnwise data as a transpose of how it is stored for GEMMS. + # * Gathering non GEMM swizzled scales. + # * Refer to scaffold code when implementing at: + # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 + raise NotImplementedError("fp8 blockwise allgather not yet implemented") + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1075,7 +1144,9 @@ def gather_along_first_dim( async_op: bool = False, quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: - """All-gather tensors and concatenate along first dimension.""" + """ + All-gather tensors and concatenate along first dimension. + """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) @@ -1100,6 +1171,16 @@ def gather_along_first_dim( out_shape=out_shape, ) + # FP8 block scaling case, block length = 128 + if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + return _all_gather_fp8_blockwise( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # MXFP8 case if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 38f829c079..c02ff73391 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools import os from contextlib import contextmanager from collections import deque @@ -19,6 +20,7 @@ Format, MXFP8BlockScaling, Float8CurrentScaling, + Float8BlockScaling, ) from .constants import dist_group_type @@ -49,6 +51,17 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ): + return True, "" + return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." + + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -109,6 +122,8 @@ class FP8GlobalStateManager: skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None @classmethod def reset(cls) -> None: @@ -134,6 +149,8 @@ def reset(cls) -> None: cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -161,6 +178,15 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + if cls.fp8_block_scaling_available is None: + cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( + check_fp8_block_scaling_support() + ) + return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -434,6 +460,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -786,8 +815,10 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState + elif recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState else: - raise ValueError("{recipe.__class__.__name__} is not supported") + raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, @@ -928,3 +959,108 @@ def make_quantizers(self) -> list: from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 31a464caad..65f47a0817 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,8 +35,10 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -516,6 +519,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8CurrentScalingRecipeState ): return + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -858,7 +865,13 @@ def grad_output_preprocess( if ctx.ub_overlap_ag: # Quantize the gradient if needed if not isinstance( - grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ): grad_output = quantizer(grad_output) @@ -876,11 +889,21 @@ def grad_output_preprocess( # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: - if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + if isinstance(quantizer, Float8BlockQuantizer): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9cd52b1e5..1ea66a7f2c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -91,6 +91,8 @@ def forward( # TODO Support Float8 Current Scaling # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f49bad48c3..df3ae05f31 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -57,9 +57,11 @@ restore_from_saved, ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + from ..cpp_extensions import ( general_gemm, ) @@ -138,11 +140,6 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - # Avoid quantized norm kernel if norm output will be returned - with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered - ) - tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output @@ -175,6 +172,18 @@ def forward( columnwise_usage = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. + force_hp_blockwise_ln_out_gather = ( + fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + with_quantized_norm = ( + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not force_hp_blockwise_ln_out_gather + ) + # Apply normalization nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( @@ -211,7 +220,7 @@ def forward( ln_out_total = input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: @@ -317,6 +326,7 @@ def forward( ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: @@ -327,6 +337,10 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # For force_hp_blockwise_ln_out_gather, we should + # be saving the unquantized ln_out to ctx. + assert not force_hp_blockwise_ln_out_gather + # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) @@ -605,11 +619,14 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + # async_op is not compatible with high precision gather since + # gather_along_first_dim does not offer callback chaining. + gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -690,6 +707,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather may have been done in BF16 + # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..e51fe43cc0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -52,7 +52,6 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, ) - from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing @@ -62,6 +61,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..tensor.quantized_tensor import ( @@ -104,17 +104,19 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } # no activation fusion written yet - # Per-tensor current scaling: [] - return { - "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), - "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, None), - "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, None), - } + # Per-tensor current scaling or fp8 blockwise scaling: [] + if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), + } + raise NotImplementedError(f"Unhandled recipe type {recipe}") def _act_func(activation: str, recipe: Optional[Recipe] = None): @@ -122,7 +124,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] - # Per-tensor current scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling: [] funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -214,12 +216,20 @@ def forward( with_quantized_norm = ( fp8 and not return_layernorm_output and not return_layernorm_output_gathered ) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + # Kernels not available for norm fusion. + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe + force_hp_fc1_input_gather = ( + fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. + # Configure quantizer for norm output if fp8: if fc1_input_quantizer is None: @@ -261,12 +271,13 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8: - ln_out = fc1_input_quantizer(ln_out) + if not force_hp_fc1_input_gather: + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: if fp8: - if not with_quantized_norm: + if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -282,7 +293,10 @@ def forward( quantizer=(fc1_input_quantizer if fp8 else None), ) else: - if fp8 and not with_quantized_norm: + # NOTE: force_hp_fc1_input_gather is redundant with else, but + # here for clarity. We should not quantize ln_out if bwd needs + # to gather in hp. + if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out @@ -336,6 +350,7 @@ def forward( # - bias_gelu_fusion - only for full precision. # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer if activation != "gelu": + # blockwise scaled gemms don't support gemm_gelu_fusion in fwd. gemm_gelu_fusion = bias_gelu_fusion = False else: if fp8: @@ -376,7 +391,12 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs else: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise. + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -462,6 +482,8 @@ def forward( if not return_layernorm_output: clear_tensor_data(ln_out) ln_out = None + elif force_hp_fc1_input_gather: + assert not isinstance(ln_out, QuantizedTensor) if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None @@ -490,6 +512,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -505,6 +528,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -696,11 +720,12 @@ def backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) + gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) else: ln_out_total = ln_out @@ -712,12 +737,13 @@ def backward( ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - # There are 5 possible fusion paths + # There are 6 possible fusion paths # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) ) @@ -753,6 +779,9 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + grad_arg = True + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + grad_arg = False fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( act_out, grad_output, @@ -764,14 +793,18 @@ def backward( ), quantization_params=None, # wgrad in high precision layout="NT", - grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + grad=grad_arg, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if fc2_bias_grad is None: + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None: + # BGRAD not fused with GEMM for float8 blockwise gemm. + fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ clear_tensor_data(act_out) # bias computation @@ -808,7 +841,14 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO float8 blockwise current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -904,6 +944,13 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None + if ctx.fc1_input_quantizer is not None and not isinstance( + ln_out_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. + ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total) # Make sure GEMM inputs have required data if isinstance(ln_out_total, QuantizedTensor): @@ -1556,7 +1603,8 @@ def _get_quantizers(self, fp8_output): fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( - rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + rowwise=True, + columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), ) fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ca9dd29043..2887b2e452 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -60,9 +60,10 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase - +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + __all__ = ["Linear"] @@ -130,6 +131,10 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False + # TODO(kwyss): Support FP8 allgather for FP8 block quantization. + force_hp_input_gather = ( + fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) + ) # Perform TP communication in high precision. if fp8: assert_dim_for_fp8_exec(inputmat, weight) if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( @@ -143,19 +148,27 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: - if not isinstance(inputmat, QuantizedTensor): - columnwise_usage = backward_needs_input and isinstance( - input_quantizer, MXFP8Quantizer + if force_hp_input_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, tp_group, quantizer=input_quantizer + ) + else: + if not isinstance(inputmat, QuantizedTensor): + columnwise_usage = backward_needs_input and isinstance( + input_quantizer, MXFP8Quantizer + ) + # force_hp_input_gather should enforce this + assert not isinstance(input_quantizer, Float8BlockQuantizer) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, ) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - inputmat = input_quantizer(inputmat) - own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=False) - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=input_quantizer, - ) else: if ( FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() @@ -277,6 +290,8 @@ def forward( # can be allgathered. if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if force_hp_input_gather: + assert not isinstance(inputmat, QuantizedTensor) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. @@ -323,8 +338,9 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -520,11 +536,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, - quantizer=quantizer, + quantizer=gather_quantizer, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -610,6 +627,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None + if ctx.input_quantizer is not None and not isinstance( + inputmat_total, QuantizedTensor + ): + # Async gather in BF16 does not asynchronously + # call quantizer after gather. + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) # Make sure GEMM inputs have required data if isinstance(inputmat_total, QuantizedTensor): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index b451acea9a..f4d4254537 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,6 +23,7 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext @@ -483,6 +484,12 @@ def _functional_forward( "Attempting to generate MXFP8 output tensor, " "but GEMM with MXFP8 output is not supported" ) + if isinstance(output_quantizer, Float8BlockQuantizer): + raise RuntimeError( + "Attempting to generate Float8BlockQuantized output tensor, " + "but GEMM with Float8BlockQuantized output is not supported" + ) + if output_quantizer is not None: output_quantizer.set_usage(rowwise=True, columnwise=False) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index ad32055479..802f4c25e3 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -17,6 +17,7 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, fp8_autocast, @@ -219,6 +220,11 @@ def _reset_quantization_recipe_state( if num_quantizers == 0: continue + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) + # Construct quantization recipe state recipe_state = RecipeState.create( recipe, @@ -260,8 +266,13 @@ def _update_quantization_recipe_state( continue recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( - recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) + or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + or ( + recipe.float8_block_scaling() + and not isinstance(recipe_state, Float8BlockScalingRecipeState) + ) + ) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 9135237854..7dc380606d 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -36,8 +36,8 @@ class Float8BlockwiseQTensorBase: def __new__( cls, *args, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, @@ -71,10 +71,16 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward""" + """ + Prepare the tensor base for saving for backward + + This does not clear the tensors currently, because with PP config + that clears the weight cache between micro-batches. If the rowwise + data is not required for backward, this is a possible memory + pessimization, but is consistent with the other quantized tensor + classes. + """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 138d1fd29e..695c5ffb8c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -44,7 +44,6 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - assert rowwise self.dtype = fp8_dtype self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales @@ -168,6 +167,11 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.append(shape[i]) return tuple(colwise_shape) + # TODO(kwyss): With FP8 gather support, we need to implement a + # shape/layout/swizzle check to know whether FP8 gather works + # cleanly by stacking data without aliasing tiles and whether + # the scales also stack on the proper dimensions. + def make_empty( self, shape: Iterable[int], @@ -181,13 +185,16 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -489,7 +496,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst.dtype = src.dtype # Check that tensor dimensions match if ( From 2856c3e09e935baf2d3fbfc583739873a2dfbfeb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 10 Apr 2025 18:16:25 -0700 Subject: [PATCH 15/53] Add user to TE CI (#1669) Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 8c3b84b1b9..f86fdd1066 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -50,6 +50,7 @@ jobs: || github.actor == 'BestJuly' || github.actor == 'xiaopoc' || github.actor == 'jreiffers' + || github.actor == 'lhb8125' ) steps: - name: Check if comment is issued by authorized person From d91ed12283934b9f2d420cd4489f1bf1fd9a4684 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 11 Apr 2025 07:44:20 -0700 Subject: [PATCH 16/53] add wgrad split for layernorm_mlp Signed-off-by: Hongbin Liu --- tests/pytorch/test_numerics.py | 66 ++++++++++- transformer_engine/pytorch/module/_common.py | 7 +- transformer_engine/pytorch/module/base.py | 20 +++- .../pytorch/module/grouped_linear.py | 5 +- .../pytorch/module/layernorm_linear.py | 16 +-- .../pytorch/module/layernorm_mlp.py | 108 +++++++++++++----- transformer_engine/pytorch/module/linear.py | 16 +-- 7 files changed, 174 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 946f0c74d9..5d7338ee69 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1044,8 +1044,8 @@ def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): out = out[0] loss = out.sum() loss.backward() - if split_bw and hasattr(block, "wgrad_comp"): - block.wgrad_comp() + if split_bw: + block.backward_dw() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] @@ -1554,6 +1554,64 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("activation", all_activations) +@pytest.mark.parametrize("normalization", all_normalizations) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +def test_layernorm_mlp_accuracy_split_bw(dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation): + config = model_configs[model] + + ln_mlp_split_bw = LayerNormMLP( + hidden_size=config.hidden_size, + ffn_hidden_size=4 * config.hidden_size, + eps=config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + device="cuda", + split_bw=True, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + ln_mlp_ref = LayerNormMLP( + hidden_size=config.hidden_size, + ffn_hidden_size=4 * config.hidden_size, + eps=config.eps, + bias=bias, + normalization=normalization, + params_dtype=dtype, + device="cuda", + split_bw=False, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + + + # Share params + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp_split_bw.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp_split_bw.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp_split_bw.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp_split_bw.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp_split_bw.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp_split_bw.fc2_bias.clone()) + if fuse_wgrad_accumulation: + ln_mlp_split_bw.fc1_weight.main_grad = torch.rand_like(ln_mlp_split_bw.fc1_weight, dtype=torch.float32) + ln_mlp_ref.fc1_weight.main_grad = ln_mlp_split_bw.fc1_weight.main_grad.clone() + ln_mlp_split_bw.fc2_weight.main_grad = torch.rand_like(ln_mlp_split_bw.fc2_weight, dtype=torch.float32) + ln_mlp_ref.fc2_weight.main_grad = ln_mlp_split_bw.fc2_weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(ln_mlp_split_bw, bs, dtype, config, split_bw=True) + te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, split_bw=False) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + def _test_grouped_linear_accuracy( block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw=False @@ -1601,10 +1659,10 @@ def _test_grouped_linear_accuracy( loss.backward() if split_bw: if isinstance(block, GroupedLinear): - block.wgrad_comp() + block.backward_dw() else: for i in range(num_gemms): - block[i].wgrad_comp() + block[i].backward_dw() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 11d2e40b96..e1c98c4810 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -281,8 +281,11 @@ def pop(self): tensor_list, func = self.context.get() return func(*tensor_list), tensor_list else: - rank = torch.distributed.get_rank() - raise Exception(f"Pop empty queue. rank {rank}") + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + raise Exception(f"Pop empty queue. rank {rank}") + else: + raise Exception("Pop empty queue. No distributed environment detected.") def assert_empty(self): """ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 31a464caad..909f9df4e6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -18,7 +18,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from ._common import _ParameterInitMeta +from ._common import _ParameterInitMeta, noop_cat from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, @@ -1081,3 +1081,21 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ + (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + unfused_weights = [getattr(self, name) for name in self.weight_names] + weight_tensor = noop_cat(unfused_weights) + if weight_tensor.grad is None: + weight_tensor.grad = wgrad.to(weight_tensor.dtype) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) + del grad_bias_ + del wgrad diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 265958a884..b24deac6b0 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -693,7 +693,8 @@ def forward( return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out - def wgrad_comp(self): + + def backward_dw(self): """ Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. @@ -714,3 +715,5 @@ def wgrad_comp(self): if bias_param.grad is None: bias_param.grad = grad_biases_[i].to(bias_param.dtype) del grad_biases_ + del wgrad_list + del tensor_list \ No newline at end of file diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 070af70da9..86b47b26b9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1469,7 +1469,8 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - def wgrad_comp(self): + + def backward_dw(self): """ Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. @@ -1477,15 +1478,4 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): - (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() - if not self.fuse_wgrad_accumulation: - unfused_weights = [getattr(self, name) for name in self.weight_names] - weight_tensor = noop_cat(unfused_weights) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - if bias_tensor.grad is None: - bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) - del grad_bias_ - del wgrad + super().backward_dw() \ No newline at end of file diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7dae573688..e9aacc8e9a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -8,6 +8,7 @@ from typing import Callable, Optional, Tuple, Union from functools import reduce from operator import mul as multiply_op +import functools import torch from torch.nn.parameter import Parameter @@ -62,7 +63,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ._common import apply_normalization, _fix_gathered_fp8_transpose +from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..tensor.quantized_tensor import ( QuantizedTensor, @@ -148,6 +149,7 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, + wgrad_store: WeightGradStore, fuse_wgrad_accumulation: bool, fc1_input_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer], @@ -540,6 +542,8 @@ def forward( if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store + # Row Parallel Linear if ub_overlap_rs: fc2_out = rs_out @@ -753,26 +757,31 @@ def backward( if isinstance(grad_output, QuantizedTensor): grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) - fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( - act_out, - grad_output, - get_workspace(), - out_dtype=( - origin_fc2_weight.main_grad.dtype - if ctx.fuse_wgrad_accumulation - else ctx.activation_dtype - ), - quantization_params=None, # wgrad in high precision + general_gemm_fc2_wgrad = functools.partial(general_gemm, + out_dtype=origin_fc2_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + workspace=get_workspace(), layout="NT", grad=True, - bias=fc2_bias if fc2_bias_grad is None else None, + bias=(fc2_bias if fc2_bias_grad is None else None), accumulate=accumulate_wgrad_into_param_main_grad, use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) - if fc2_bias_grad is None: - fc2_bias_grad = fc2_bias_grad_ - clear_tensor_data(act_out) + if ctx.wgrad_store.split_bw(): + ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) + fc2_wgrad = None + fc2_bias_grad = None + # (fc2_wgrad, fc2_bias_grad_, *_), _ = ctx.wgrad_store.pop() + else: + fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( + act_out, + grad_output, + ) + + if fc2_bias_grad is None: + fc2_bias_grad = fc2_bias_grad_ + if not ctx.wgrad_store.split_bw(): + clear_tensor_data(act_out) # bias computation fc1_bias_grad = None @@ -919,18 +928,12 @@ def backward( ) # wgrad GEMM - fc1_wgrad_outputs = general_gemm( - ln_out_total, - dact, - get_workspace(), - out_dtype=( - origin_fc1_weight.main_grad.dtype - if ctx.fuse_wgrad_accumulation - else ctx.activation_dtype - ), + general_gemm_fc1_wgrad = functools.partial(general_gemm, + out_dtype=origin_fc1_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + workspace=get_workspace(), layout="NT", grad=fuse_gemm_and_bias_fc1_wgrad, - bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, + bias=(fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None), accumulate=accumulate_wgrad_into_param_main_grad, out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub=ub_obj_fc1_wgrad, @@ -938,13 +941,24 @@ def backward( extra_output=fc1_dgrad_rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + if ctx.wgrad_store.split_bw(): + ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) + fc1_wgrad = None + # (fc1_wgrad_outputs), _ = ctx.wgrad_store.pop() + if fuse_gemm_and_bias_fc1_wgrad: + fc1_bias_grad = None + else: + fc1_wgrad_outputs = general_gemm_fc1_wgrad( + ln_out_total, + dact, + ) - clear_tensor_data(ln_out_total, dact) + clear_tensor_data(ln_out_total, dact) - if fuse_gemm_and_bias_fc1_wgrad: - fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs - else: - fc1_wgrad, *_ = fc1_wgrad_outputs + if fuse_gemm_and_bias_fc1_wgrad: + fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs + else: + fc1_wgrad, *_ = fc1_wgrad_outputs if ctx.ub_bulk_wgrad: if ub_obj_fc1_wgrad.is_fp8_ubuf(): @@ -1061,6 +1075,7 @@ def backward( None, # is_first_microbatch None, # fp8 None, # fp8_calibration + None, # wgrad_store None, # fuse_wgrad_accumulation None, # fc1_input_quantizer None, # fc1_weight_quantizer @@ -1220,6 +1235,7 @@ def __init__( ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, + split_bw: bool = False, ) -> None: super().__init__() @@ -1246,6 +1262,8 @@ def __init__( and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) + self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1487,6 +1505,7 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, + self.wgrad_store, self.fuse_wgrad_accumulation, fc1_input_quantizer, fc1_weight_quantizer, @@ -1651,3 +1670,32 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. + """ + if not self.wgrad_store.split_bw(): + return + with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): + (fc2_wgrad, fc2_bias_grad_, *_), _ = self.wgrad_store.pop() + if self.use_bias and self.fc1_bias.grad is None: + (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop() + else: + (fc1_wgrad, *_), _ = self.wgrad_store.pop() + fc1_bias_grad = None + if self.use_bias: + if self.fc2_bias.grad is None: + self.fc2_bias.grad = fc2_bias_grad_.to(self.fc2_bias.dtype) + if self.fc1_bias.grad is None: + self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) + if not self.fuse_wgrad_accumulation: + if self.fc2_weight.grad is None: + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + if self.fc1_weight.grad is None: + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + del fc2_bias_grad_ + del fc2_wgrad + del fc1_wgrad + del fc1_bias_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 27d4811799..a394919af2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1275,7 +1275,8 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - def wgrad_comp(self): + + def backward_dw(self): """ Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. @@ -1283,15 +1284,4 @@ def wgrad_comp(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_Linear_wgrad"): - (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() - if not self.fuse_wgrad_accumulation: - unfused_weights = [getattr(self, name) for name in self.weight_names] - weight_tensor = noop_cat(unfused_weights) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - if bias_tensor.grad is None: - bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) - del grad_bias_ - del wgrad + super().backward_dw() \ No newline at end of file From 3b38bb4e340f58412c840369c6ac83a0179f92ee Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 11 Apr 2025 08:10:39 -0700 Subject: [PATCH 17/53] minor fix Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/layernorm_mlp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index adfd2f8486..4f525f3748 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1740,8 +1740,7 @@ def backward_dw(self): fc1_bias_grad = None if self.use_bias: if self.fc2_bias.grad is None: - if self.fp8 and self.fp8_recipe.float8_block_scaling() - and self.apply_bias and not self.gemm_bias_unfused_add: + if self.fp8 and self.fp8_recipe.float8_block_scaling() and self.apply_bias and not self.gemm_bias_unfused_add: act_out = tensor_list_fc2[0] # BGRAD not fused with GEMM for float8 blockwise gemm. fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) From 7ec418299937e98a5ca9c841d0d45cf5d658da95 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:11:15 +0000 Subject: [PATCH 18/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 16 ++++++++---- .../pytorch/module/grouped_linear.py | 3 +-- .../pytorch/module/layernorm_linear.py | 3 +-- .../pytorch/module/layernorm_mlp.py | 26 ++++++++++++++----- transformer_engine/pytorch/module/linear.py | 3 +-- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 9b22a9da83..c574704c59 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1562,6 +1562,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) @@ -1569,7 +1570,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -def test_layernorm_mlp_accuracy_split_bw(dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation): +def test_layernorm_mlp_accuracy_split_bw( + dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation +): config = model_configs[model] ln_mlp_split_bw = LayerNormMLP( @@ -1594,8 +1597,7 @@ def test_layernorm_mlp_accuracy_split_bw(dtype, bs, model, activation, normaliza device="cuda", split_bw=False, fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - + ).eval() # Share params with torch.no_grad(): @@ -1608,9 +1610,13 @@ def test_layernorm_mlp_accuracy_split_bw(dtype, bs, model, activation, normaliza ln_mlp_ref.fc1_bias = Parameter(ln_mlp_split_bw.fc1_bias.clone()) ln_mlp_ref.fc2_bias = Parameter(ln_mlp_split_bw.fc2_bias.clone()) if fuse_wgrad_accumulation: - ln_mlp_split_bw.fc1_weight.main_grad = torch.rand_like(ln_mlp_split_bw.fc1_weight, dtype=torch.float32) + ln_mlp_split_bw.fc1_weight.main_grad = torch.rand_like( + ln_mlp_split_bw.fc1_weight, dtype=torch.float32 + ) ln_mlp_ref.fc1_weight.main_grad = ln_mlp_split_bw.fc1_weight.main_grad.clone() - ln_mlp_split_bw.fc2_weight.main_grad = torch.rand_like(ln_mlp_split_bw.fc2_weight, dtype=torch.float32) + ln_mlp_split_bw.fc2_weight.main_grad = torch.rand_like( + ln_mlp_split_bw.fc2_weight, dtype=torch.float32 + ) ln_mlp_ref.fc2_weight.main_grad = ln_mlp_split_bw.fc2_weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_mlp_split_bw, bs, dtype, config, split_bw=True) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 7fbb9aedcc..473e19b8aa 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -695,7 +695,6 @@ def forward( return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out - def backward_dw(self): """ Execute the delayed weight gradient computation. @@ -718,4 +717,4 @@ def backward_dw(self): bias_param.grad = grad_biases_[i].to(bias_param.dtype) del grad_biases_ del wgrad_list - del tensor_list \ No newline at end of file + del tensor_list diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 82d4afb8c6..7a2857670c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1493,7 +1493,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - def backward_dw(self): """ Execute the delayed weight gradient computation. @@ -1502,4 +1501,4 @@ def backward_dw(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): - super().backward_dw() \ No newline at end of file + super().backward_dw() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4f525f3748..7f9ac1aa1b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -786,13 +786,13 @@ def backward( grad_arg = True if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False - general_gemm_fc2_wgrad = functools.partial(general_gemm, + general_gemm_fc2_wgrad = functools.partial( + general_gemm, out_dtype=( origin_fc2_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - workspace=get_workspace(), layout="NT", grad=grad_arg, @@ -812,7 +812,11 @@ def backward( ) if fc2_bias_grad is None: - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None: + if ( + ctx.fp8 + and ctx.fp8_recipe.float8_block_scaling() + and fc2_bias is not None + ): # BGRAD not fused with GEMM for float8 blockwise gemm. fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) @@ -980,8 +984,13 @@ def backward( ) # wgrad GEMM - general_gemm_fc1_wgrad = functools.partial(general_gemm, - out_dtype=origin_fc1_weight.main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype, + general_gemm_fc1_wgrad = functools.partial( + general_gemm, + out_dtype=( + origin_fc1_weight.main_grad.dtype + if ctx.fuse_wgrad_accumulation + else ctx.activation_dtype + ), workspace=get_workspace(), layout="NT", grad=fuse_gemm_and_bias_fc1_wgrad, @@ -1740,7 +1749,12 @@ def backward_dw(self): fc1_bias_grad = None if self.use_bias: if self.fc2_bias.grad is None: - if self.fp8 and self.fp8_recipe.float8_block_scaling() and self.apply_bias and not self.gemm_bias_unfused_add: + if ( + self.fp8 + and self.fp8_recipe.float8_block_scaling() + and self.apply_bias + and not self.gemm_bias_unfused_add + ): act_out = tensor_list_fc2[0] # BGRAD not fused with GEMM for float8 blockwise gemm. fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e48c726d92..77a56ef83a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1299,7 +1299,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - def backward_dw(self): """ Execute the delayed weight gradient computation. @@ -1308,4 +1307,4 @@ def backward_dw(self): if not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_Linear_wgrad"): - super().backward_dw() \ No newline at end of file + super().backward_dw() From dfb3c486977c5e8e54a50d7e721f4882a7a1970d Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Fri, 11 Apr 2025 10:50:05 -0700 Subject: [PATCH 19/53] Make shape cache invalidation more conservative. (#1670) Repeated calls to nvte_shape should not invalidate previous data pointers. It would be possible to avoid unnecessary comparisons by duplicating some of the logic from shape() so that the cache is only relevant when columnwise shapes are involved. Whether this code duplication is preferable to the comparisons that arise from by value semantics of reusing shape is a judgment call. Signed-off-by: Keith Wyss --- transformer_engine/common/common.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 728f8ad147..a852bda410 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -195,7 +195,18 @@ struct Tensor { } const std::vector &rowwise_shape_ref() const { - rowwise_shape_cache = shape(); + auto shape_queried = shape(); + // This method is primarily designed for nvte_shape. + // An unfortunate consequence of unconditionally assigning + // values to rowwise_shape_cache without a check is that + // repeated calls to rowwise_shape_ref are likely to + // invalidate the data pointers from previous calls. + // If the shape has changed, then invalidating is necessary + // in at least some cases, but we want to keep the data + // valid otherwise. + if (rowwise_shape_cache != shape_queried) { + rowwise_shape_cache = std::move(shape_queried); + } return rowwise_shape_cache; } From 04642bf8910da5b800b10d98e32ad1a310e106cd Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 11 Apr 2025 13:39:13 -0700 Subject: [PATCH 20/53] [PyTorch] Add option in activation ops to cache input in FP8 (#1665) * Add option to cache activation input in FP8 Signed-off-by: Tim Moon * Avoid casting to FP8 transpose Signed-off-by: Tim Moon * Skip input caching if device is not supported Signed-off-by: Tim Moon * Add documentation that FP8 input caching is experimental Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 10 +++++-- .../pytorch/ops/basic/activation.py | 30 ++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 59af228861..c2b32ca272 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1394,6 +1394,7 @@ def test_make_extra_output( @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("cache_quantized_input", (False, True)) def test_activation( self, *, @@ -1402,6 +1403,7 @@ def test_activation( dtype: torch.dtype, device: torch.device = "cuda", quantization: Optional[str], + cache_quantized_input: bool, ) -> None: """Activation functions""" @@ -1413,6 +1415,8 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) + if cache_quantized_input: + maybe_skip_quantization("fp8", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1463,7 +1467,7 @@ def test_activation( )[activation] forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantized_compute), - make_op(), + make_op(cache_quantized_input=cache_quantized_input), te_ops.Quantize(forward=quantized_compute, backward=False), ) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): @@ -1472,9 +1476,9 @@ def test_activation( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute: + if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu": + if activation == "relu" and not cache_quantized_input: tols = {"atol": 0, "rtol": 0} # Check results diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 45c78bea87..aa0bb1a52b 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from ...fp8 import FP8GlobalStateManager from ...tensor import QuantizedTensor +from ...tensor.float8_tensor import Float8CurrentScalingQuantizer from ...utils import clear_tensor_data, devices_match from ..op import BasicOperation, OperationContext from .._common import reshape @@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): the first half of the input tensor, while PyTorch applies it to the second half. + Parameters + ---------- + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward + pass. This will typically reduce memory usage but require + extra compute and increase numerical error. This feature is + highly experimental. + """ + def __init__(self, *, cache_quantized_input: bool = False): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + @abc.abstractmethod def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: """Forward implementation @@ -100,9 +113,16 @@ def op_forward( if y.dim() != x.dim(): y = y.reshape(list(x.shape[:-1]) + [-1]) + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + quantizer.set_usage(rowwise=True, columnwise=False) + x = quantizer(x) + # Save state for backward pass ctx.save_for_backward(x.detach()) ctx.fp8_enabled = fp8_enabled + ctx.dtype = dtype ctx.prev_op = prev_op return y @@ -116,10 +136,18 @@ def op_backward( # Saved tensors from forward pass (x,) = ctx.saved_tensors + # Check input tensor + if isinstance(x, QuantizedTensor): + x = x.dequantize(dtype=ctx.dtype) + elif x.dtype != ctx.dtype: + x = x.to(dtype=ctx.dtype) + if not x.is_contiguous(): + x = x.contiguous() + # Check grad output tensor dy = grad_output if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() + dy = dy.dequantize(dtype=ctx.dtype) if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: dy = dy.to(device=x.device, dtype=x.dtype) if not dy.is_contiguous(): From c638c4362b0286e032a74d49055b890c4ca27061 Mon Sep 17 00:00:00 2001 From: linxiddd Date: Sat, 12 Apr 2025 09:37:04 +0800 Subject: [PATCH 21/53] [QA] Extend error handling (#1660) [QA] Add error handling - Standardize test failure handling using the unified 'test_fail' function and 'error_exit' function Signed-off-by: Linxi Ding Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 6 ++++-- qa/L2_jax_unittest/test.sh | 39 +++++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 7989eaf528..4cde60b5a9 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -2,6 +2,8 @@ # # See LICENSE for license information. +set -x + function error_exit() { echo "Error: $1" exit 1 @@ -23,10 +25,10 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls" +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "mnist" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index ec651a1317..59212e38e1 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -2,22 +2,43 @@ # # See LICENSE for license information. -set -xe +set -x -pip install "nltk>=3.8.2" -pip install pytest==8.2.1 +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py +NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" -pip install -r $TE_PATH/examples/jax/mnist/requirements.txt -pip install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "mnist" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 From d9eb0582b5bebce3354e1a2ed249c8efa688cf2d Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 14 Apr 2025 00:02:55 -0700 Subject: [PATCH 22/53] [PyTorch] Added attention activation offloading support for TE v2.0 (#1671) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added attention activation offloading support for TE v2.0 Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0d442435bf..8a3f259575 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -80,6 +80,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb +from .cpu_offload import set_offloading_param # Setup Attention Logging @@ -4324,7 +4325,7 @@ def forward( tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] for tensor in tensor_list: if tensor is not None: - tensor.activation_offloading = True + set_offloading_param(tensor, "activation_offloading", True) with self.attention_dropout_ctx(): # | API | use cases @@ -4726,12 +4727,14 @@ def forward( else: tensor_list = [q, k, v, out_save] - tensor_list.extend(aux_ctx_tensors) - qkv_layout = "sbhd_sbhd_sbhd" for tensor in tensor_list: if tensor is not None: - tensor.activation_offloading = True + set_offloading_param(tensor, "activation_offloading", True) + + for tensor in aux_ctx_tensors: + if tensor is not None: + set_offloading_param(tensor, "activation_offloading", True) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 From c8e7cc024ee2f5af8fbb143bba0a04476770564f Mon Sep 17 00:00:00 2001 From: Autumn1998 <1515848689@qq.com> Date: Mon, 14 Apr 2025 16:19:34 +0800 Subject: [PATCH 23/53] [MoE] Support new fp8 recipes for permute_fusion (#1649) * add support for new recipe on permute_fusion, rm fp unpermute Signed-off-by: tongliu * fix lint Signed-off-by: Xin Yao * remove fp8 from index map Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * skip unsupported tests Signed-off-by: Xin Yao --------- Signed-off-by: tongliu Signed-off-by: Xin Yao Co-authored-by: tongliu Co-authored-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_permutation.py | 451 +++++++----------- .../common/permutation/permutation.cu | 4 +- .../pytorch/csrc/extensions/permutation.cu | 39 +- transformer_engine/pytorch/permutation.py | 302 ++++++------ .../pytorch/triton/permutation.py | 95 ++-- 5 files changed, 384 insertions(+), 507 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 0dc183e298..07b5f7c529 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -8,6 +8,7 @@ import pytest from typing import Dict, List +from transformer_engine.common import recipe from transformer_engine.pytorch import ( moe_permute as te_permute, moe_permute_with_probs as te_permute_with_probs, @@ -17,9 +18,14 @@ ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex - +import copy seed = 1234 torch.manual_seed(seed) @@ -234,7 +240,6 @@ def _test_permutation_index_map( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -242,48 +247,12 @@ def _test_permutation_index_map( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - permute_bwd_input = torch.rand( - size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _permute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) - permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) - unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -323,9 +292,9 @@ def _test_permutation_index_map( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + te_permute_bwd_input = pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( te_permute_fwd_input, indices, num_out_tokens, map_type="index" @@ -338,7 +307,7 @@ def _test_permutation_index_map( te_probs.requires_grad_(True) te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_unpermute_output = te_unpermute( te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" @@ -352,16 +321,10 @@ def _test_permutation_index_map( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_permute_output_ = te_permute_output.float() - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() torch.testing.assert_close( pytorch_permute_output.float(), @@ -487,7 +450,6 @@ def _test_permutation_mask_map( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -495,49 +457,12 @@ def _test_permutation_mask_map( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - permute_bwd_input = torch.rand( - size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _permute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) - permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) - unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -553,10 +478,7 @@ def _test_permutation_mask_map( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums - if fp8: - probs = probs.to(torch.float16) - else: - probs = probs.to(dtype) + probs = probs.to(dtype) probs.requires_grad_(True) ################################################################################################################################### @@ -582,9 +504,9 @@ def _test_permutation_mask_map( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + te_permute_bwd_input = pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" @@ -597,7 +519,7 @@ def _test_permutation_mask_map( te_probs.requires_grad_(True) te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_unpermute_output = te_unpermute( te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" @@ -611,16 +533,10 @@ def _test_permutation_mask_map( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_permute_output_ = te_permute_output.float() - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() torch.testing.assert_close( pytorch_permute_output.float(), @@ -730,6 +646,118 @@ def _test_permutation_mask_map( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_mask_map_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + recipe, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + if recipe.delayed(): + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + elif recipe.float8_current_scaling(): + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=te_dtype, + device=torch.device("cuda"), + columnwise=False, + ) + elif recipe.float8_block_scaling(): + quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=0.0, + force_pow_2_scales=True, # Fp8 sub-channel a2a requires e8 scales + block_scaling_dim=1, # 1x128 scaling + ) + elif recipe.mxfp8(): + quantizer = MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + ) + else: + raise ValueError("Unsupported FP8 recipe") + + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + # Make an empty fp8 tensor + permute_fwd_input_fp8 = quantizer.make_empty( + permute_fwd_input.shape, + dtype=permute_fwd_input.dtype, + device=permute_fwd_input.device, + ) + # quantize the tensor + quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8) + if recipe.float8_block_scaling(): + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data) + pytorch_permute_fwd_scale_input = copy.deepcopy( + permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous() + ) + elif recipe.mxfp8(): + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data) + pytorch_permute_fwd_scale_input = copy.deepcopy( + permute_fwd_input_fp8._rowwise_scale_inv.contiguous() + ) + else: + pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data) + pytorch_permute_fwd_scale_input = None + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + # PyTorch Permutaion + pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) + if pytorch_permute_fwd_scale_input is not None: + pytorch_permute_scale_output, _ = pytorch_permute_mask_map( + pytorch_permute_fwd_scale_input, routing_map + ) + + # TE Permutation + permute_output, _ = te_permute( + permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + ) + if recipe.float8_block_scaling(): + te_permute_output = permute_output._rowwise_data + te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous() + elif recipe.mxfp8(): + te_permute_output = permute_output._rowwise_data + te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous() + else: + te_permute_output = permute_output._data + te_permute_scale_output = None + + # check the permute output + torch.testing.assert_close( + pytorch_permute_output, + te_permute_output, + atol=0, + rtol=0, + ) + if recipe.float8_block_scaling() or recipe.mxfp8(): + torch.testing.assert_close( + pytorch_permute_scale_output, + te_permute_scale_output, + atol=0, + rtol=0, + ) + + def _test_moe_chunk_sort( te_dtype, num_tokens, @@ -743,7 +771,6 @@ def _test_moe_chunk_sort( f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -751,34 +778,11 @@ def _test_moe_chunk_sort( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - - _fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _bwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - fwd_input = _fwd_input_quantizer.quantize(fwd_input) - bwd_input = _bwd_input_quantizer.quantize(bwd_input) - - pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16) - pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_fwd_input.requires_grad_(True) @@ -806,9 +810,9 @@ def _test_moe_chunk_sort( # TE Permutation # ################################################################################################################################### - te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach() + te_fwd_input = pytorch_fwd_input.detach() te_fwd_input.requires_grad_(True) - te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach() + te_bwd_input = pytorch_bwd_input.detach() te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) te_output.backward(te_bwd_input, retain_graph=True) @@ -820,12 +824,8 @@ def _test_moe_chunk_sort( ################################################################################################################################### tols = dtype_tols(te_dtype) - if fp8: - te_output_ = te_output.dequantize(dtype=torch.float32) - te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32) - else: - te_output_ = te_output.float() - te_fwd_input_grad = te_fwd_input.grad.float() + te_output_ = te_output.float() + te_fwd_input_grad = te_fwd_input.grad.float() torch.testing.assert_close( pytorch_output.float(), @@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs( f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) - fp8 = False # Convert TE dtypes to PyTorch dtypes if te_dtype == tex.DType.kFloat32: dtype = torch.float32 @@ -907,38 +906,11 @@ def _test_permutation_mask_map_alongside_probs( dtype = torch.float16 elif te_dtype == tex.DType.kBFloat16: dtype = torch.bfloat16 - elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): - dtype = torch.uint8 - fp8 = True else: pytest.skip("Invalid dtype.") - if fp8: - permute_fwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - unpermute_bwd_input = torch.rand( - size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" - ) - - _permute_fwd_input_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - _unpermute_bwd_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input) - unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input) - - pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) - else: - pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input.requires_grad_(True) @@ -952,10 +924,7 @@ def _test_permutation_mask_map_alongside_probs( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums - if fp8: - probs = probs.to(torch.float16) - else: - probs = probs.to(dtype) + probs = probs.to(dtype) probs.requires_grad_(True) split_sizes = [0] * (num_expert * tp_size) @@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input = pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() te_probs = probs.detach() te_probs.requires_grad_(True) - print(te_probs.shape) te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( te_permute_fwd_input, @@ -1020,27 +988,14 @@ def _test_permutation_mask_map_alongside_probs( routing_map, num_out_tokens=num_out_tokens, ) - print(te_permuted_probs.shape) te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda ) - if fp8: - _permute_output_quantizer = Float8Quantizer( - scale=torch.full([1], 1.0).cuda().squeeze(), - amax=torch.full([1], 1.0).cuda(), - fp8_dtype=te_dtype, - ) - te_permute_output = te_permute_output.dequantize(dtype=torch.float32) - te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) - te_permute_output = _permute_output_quantizer.quantize(te_permute_output) - else: - te_permute_output_dtype = te_permute_output.dtype - print(te_permute_output.shape) - print(te_permuted_probs.shape) - te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) - te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) + te_permute_output_dtype = te_permute_output.dtype + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) te_permute_output = te_sort_chunks_by_index( te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda @@ -1058,13 +1013,8 @@ def _test_permutation_mask_map_alongside_probs( tols = dtype_tols(te_dtype) - if fp8: - # backward of dequantize is in high precision - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) - else: - te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() - te_unpermute_output_ = te_unpermute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() torch.testing.assert_close( pytorch_unpermute_output.float(), @@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), +] @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -1237,36 +1197,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation_index_map_fp8( - te_dtype, - num_tokens, - num_expert, - hidden_size, - topK, - num_out_tokens, -): - with_probs = True - BENCHMARK = False - - _test_permutation_index_map( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, - ) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("recipe", fp8_recipes) def test_permutation_mask_map_fp8( te_dtype, num_tokens, @@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8( hidden_size, topK, num_out_tokens, + recipe, ): - with_probs = True - BENCHMARK = False - - _test_permutation_mask_map( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, - ) - + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) -def test_permutation_mask_map_alongside_probs_fp8( - te_dtype, - num_tokens, - num_expert, - hidden_size, - topK, - num_out_tokens, - tp_size, -): - _test_permutation_mask_map_alongside_probs( + _test_permutation_mask_map_fp8( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, hidden_size=hidden_size, topK=topK, num_out_tokens=num_out_tokens, - tp_size=tp_size, + recipe=recipe, ) @@ -1415,11 +1320,9 @@ def test_permutation_single_case(): # te_dtype = tex.DType.kFloat32 # te_dtype = tex.DType.kFloat16 - # te_dtype = tex.DType.kBFloat16 - te_dtype = tex.DType.kFloat8E5M2 - # te_dtype = tex.DType.kFloat8E4M3 + te_dtype = tex.DType.kBFloat16 - num_tokens = 10 + num_tokens = 12 num_expert = 4 hidden_size = 16 topK = 2 diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 7e9e2a97f7..2ac38c93cf 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so const transformer_engine::Tensor *input_fwd_cu = reinterpret_cast(input_fwd); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), reinterpret_cast(output_cu->data.dptr), @@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id const transformer_engine::Tensor *prob_cu = reinterpret_cast(prob); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), reinterpret_cast(output_cu->data.dptr), diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 47282da504..1dc8dd0d17 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -52,18 +52,11 @@ std::tuple> moe_permute_fwd( sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, num_tokens * topK); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input.scalar_type(); - // Output buffer alloc num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; - at::Tensor permuted_output = torch::empty( - {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor permuted_output = + torch::empty({num_out_tokens, num_cols}, + torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); at::Tensor row_id_map = torch::empty( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); @@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d using namespace transformer_engine::pytorch; int num_cols = input.size(1); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input.scalar_type(); - // Output buffer alloc - at::Tensor unpermuted_output = torch::empty( - {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor unpermuted_output = + torch::empty({num_tokens, num_cols}, + torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -136,17 +122,10 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); - // Activations type - at::ScalarType _st; - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Byte; - else - _st = input_bwd.scalar_type(); - // Output buffer alloc - at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, - torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor act_grad = + torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false)); at::Tensor prob_grad = torch::empty( {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index dd2f60deba..d88047a012 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -4,14 +4,16 @@ """MoE Permutaion API""" import warnings -from typing import Tuple +from typing import Optional, Tuple import torch import transformer_engine_torch as tex import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.float8_tensor import Float8Tensor - +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor __all__ = [ "moe_permute", @@ -46,17 +48,7 @@ def forward( assert inp.size(0) == index.size(0), "Permute not possible" # Data type check - fp8 = isinstance(inp, Float8Tensor) - if fp8: - assert ( - inp._quantizer.scale.ndim == 0 - ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute." - dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - dtype = TE_DType[inp.dtype] + dtype = TE_DType[inp.dtype] if index.dtype != torch.int32: warnings.warn( f"The data type of the input `index` of Permute is {index.dtype}! " @@ -80,19 +72,9 @@ def forward( _moe_permute_index_map.max_expanded_token_num, ) - if fp8: - permuted_act = Float8Tensor( - data=permuted_act, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=permuted_act.shape, - dtype=fake_dtype, - ) - ctx.row_id_map = row_id_map ctx.num_tokens = index.size(0) ctx.topK = index.size(1) - ctx.fp8 = fp8 return permuted_act, row_id_map @staticmethod @@ -109,30 +91,12 @@ def backward( if not permuted_act_grad.is_contiguous(): permuted_act_grad = permuted_act_grad.contiguous() - if ctx.fp8: - assert isinstance( - permuted_act_grad, Float8Tensor - ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." - dtype = permuted_act_grad._fp8_dtype - fp8_scale_inv = permuted_act_grad._scale_inv - fake_dtype = permuted_act_grad.dtype - permuted_act_grad = permuted_act_grad._data - else: - dtype = TE_DType[permuted_act_grad.dtype] - + dtype = TE_DType[permuted_act_grad.dtype] act_grad = None if ctx.needs_input_grad[0]: act_grad = tex.moe_permute_bwd( permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK ) - if ctx.fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv * ctx.topK, - shape=act_grad.shape, - dtype=fake_dtype, - ) return act_grad, None, None, None @@ -176,14 +140,7 @@ def forward( assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check - fp8 = isinstance(inp, Float8Tensor) - if fp8: - dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - dtype = TE_DType[inp.dtype] + dtype = TE_DType[inp.dtype] if row_id_map.dtype != torch.int32: warnings.warn( f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " @@ -193,17 +150,7 @@ def forward( unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) - if fp8: - unpermuted_output = Float8Tensor( - data=unpermuted_output, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=unpermuted_output.shape, - dtype=fake_dtype, - ) - ctx.save_for_backward(inp, row_id_map, probs) - ctx.fp8 = fp8 return unpermuted_output @staticmethod @@ -219,17 +166,7 @@ def backward( if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() - if ctx.fp8: - assert isinstance( - unpermuted_act_grad, Float8Tensor - ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." - dtype = unpermuted_act_grad._fp8_dtype - fp8_scale_inv = unpermuted_act_grad._scale_inv - fake_dtype = unpermuted_act_grad.dtype - unpermuted_act_grad = unpermuted_act_grad._data - else: - dtype = TE_DType[unpermuted_act_grad.dtype] - + dtype = TE_DType[unpermuted_act_grad.dtype] inp, row_id_map, probs = ctx.saved_tensors act_grad = None @@ -238,14 +175,6 @@ def backward( act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, inp, dtype, row_id_map, probs ) - if ctx.fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=dtype, - fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, - dtype=fake_dtype, - ) if not ctx.needs_input_grad[2]: prob_grad = None @@ -282,29 +211,86 @@ def forward( row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) - fp8 = isinstance(inp, Float8Tensor) + fp8 = isinstance(inp, QuantizedTensor) + per_tensor_recipe = isinstance(inp, Float8Tensor) + blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(inp, MXFP8Tensor) + if fp8: fp8_dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv fake_dtype = inp.dtype - inp = inp._data - output, permuted_probs = triton_permutation.permute_with_mask_map( + # blockwise scaling + if blockwise_recipe: + fp8_scale = inp._rowwise_scale_inv.T.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + # mxfp8 scaling + elif mxfp8_recipe: + fp8_scale = inp._rowwise_scale_inv.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + # per-tensor scaling + elif per_tensor_recipe: + # Kernel does not need scale in per-tensor scaling + fp8_scale = None + scale_hidden_dim = None + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + raise ValueError("Unsupported FP8 recipe") + else: + fp8_scale = None + fp8_dtype = None + scale_hidden_dim = None + + output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map( inp, row_id_map, probs, + fp8_scale, num_tokens, num_experts, num_out_tokens, hidden_size, + scale_hidden_dim, ) + if fp8: - output = Float8Tensor( - data=output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=output.shape, - dtype=fake_dtype, - ) + if per_tensor_recipe: + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + output = Float8BlockwiseQTensor( + shape=output.shape, + dtype=fake_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=output.requires_grad, + ) + elif mxfp8_recipe: + output = MXFP8Tensor( + shape=output.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=output.requires_grad, + ) ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts @@ -327,14 +313,9 @@ def backward( probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors - fp8 = isinstance(permuted_act_grad, Float8Tensor) - if fp8: - fp8_dtype = permuted_act_grad._fp8_dtype - fp8_scale_inv = permuted_act_grad._scale_inv - fake_dtype = permuted_act_grad.dtype - permuted_act_grad = permuted_act_grad._data - else: - fp8_dtype = None + assert not isinstance( + permuted_act_grad, QuantizedTensor + ), "The backward of moe_permute does not support FP8." act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, @@ -343,16 +324,7 @@ def backward( ctx.num_tokens, ctx.num_experts, ctx.hidden_size, - fp8_dtype, ) - if fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv * ctx.num_experts, - shape=act_grad.shape, - dtype=fake_dtype, - ) if not ctx.needs_input_grad[3]: probs_grad = None return act_grad, None, None, probs_grad @@ -366,8 +338,8 @@ def forward( ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - merging_probs: torch.Tensor, - restore_shape: torch.Size, + merging_probs: Optional[torch.Tensor], + restore_shape: Optional[torch.Size], ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -387,17 +359,9 @@ def forward( assert inp.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - fp8 = isinstance(inp, Float8Tensor) - if fp8: - fp8_dtype = inp._fp8_dtype - if not with_probs: - fp8_scale_inv = inp._scale_inv * num_experts - else: - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data - else: - fp8_dtype = None + assert not isinstance( + inp, QuantizedTensor + ), "The forward of moe_unpermute does not support FP8." unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, @@ -406,16 +370,7 @@ def forward( num_tokens, num_experts, hidden_size, - fp8_dtype=fp8_dtype, ) - if fp8: - unpermuted_output = Float8Tensor( - data=unpermuted_output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=unpermuted_output.shape, - dtype=fake_dtype, - ) if with_probs: ctx.save_for_backward(inp, row_id_map, merging_probs) @@ -442,16 +397,44 @@ def backward(ctx, unpermuted_act_grad): else: (row_id_map,) = ctx.saved_tensors - fp8 = isinstance(unpermuted_act_grad, Float8Tensor) + fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) + per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) + blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor) + if fp8: fp8_dtype = unpermuted_act_grad._fp8_dtype - fp8_scale_inv = unpermuted_act_grad._scale_inv fake_dtype = unpermuted_act_grad.dtype - unpermuted_act_grad = unpermuted_act_grad._data + # per-tensor scaling + if per_tensor_recipe: + # Kernel does not need scale in per-tensor scaling + fp8_scale = None + scale_hidden_dim = None + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + # blockwise scaling + elif blockwise_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + # mxfp8 scaling + elif mxfp8_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + else: + raise ValueError("Unsupported FP8 recipe") else: + scale_hidden_dim = None fp8_dtype = None + fp8_scale = None if ctx.with_probs: + assert ( + not fp8 + ), "The backward of moe_unpermute with merging probs does not support FP8." act_grad, probs_grad = ( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( unpermuted_act_grad, @@ -462,28 +445,55 @@ def backward(ctx, unpermuted_act_grad): ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, - fp8_dtype, ) ) else: - act_grad, _ = triton_permutation.permute_with_mask_map( + act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map( unpermuted_act_grad, row_id_map, None, + fp8_scale, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, + scale_hidden_dim, ) if fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, - dtype=fake_dtype, - ) + if per_tensor_recipe: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + act_grad = Float8BlockwiseQTensor( + shape=act_grad.shape, + dtype=fake_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=act_grad.requires_grad, + ) + elif mxfp8_recipe: + act_grad = MXFP8Tensor( + shape=act_grad.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=act_grad.requires_grad, + ) if not ctx.needs_input_grad[2]: probs_grad = None @@ -568,10 +578,10 @@ def moe_permute_with_probs( def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, - merging_probs: torch.Tensor = None, - restore_shape: torch.Tensor = None, + merging_probs: Optional[torch.Tensor] = None, + restore_shape: Optional[torch.Size] = None, map_type: str = "mask", - probs: torch.Tensor = None, + probs: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -588,7 +598,7 @@ def moe_unpermute( The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. - restore_shape: torch.Tensor + restore_shape: torch.Size, default = None The output shape after the unpermute operation. map_type: str, default = 'mask' Type of the routing map tensor. Should be the same as the value passed to moe_permute. diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 1c5fd73581..ebf8dd551e 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -10,8 +10,6 @@ import triton import triton.language as tl -from transformer_engine_torch import DType as TE_DType - @triton.jit def _row_id_map_pass_1_kernel( @@ -116,11 +114,14 @@ def _permute_kernel( output_ptr, row_id_map_ptr, probs_ptr, + scale_ptr, permuted_probs_ptr, + permuted_scale_ptr, # sizes num_tokens, num_experts, hidden_size, + scale_hidden_dim, # strides stride_input_token, stride_input_hidden, @@ -128,9 +129,14 @@ def _permute_kernel( stride_output_hidden, stride_probs_token, stride_probs_expert, + stride_scale_token, + stride_scale_hidden, stride_permuted_probs_token, + stride_permuted_scale_token, + stride_permuted_scale_hidden, # metas PERMUTE_PROBS: tl.constexpr, + PERMUTE_SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -140,11 +146,21 @@ def _permute_kernel( mask = cur_off < hidden_size input_off = pid * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) + if PERMUTE_SCALE: + mask_scale = cur_off < scale_hidden_dim + scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden + scale = tl.load(scale_ptr + scale_off, mask=mask_scale) for expert_idx in range(num_experts): dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) if dst_row != -1: output_off = dst_row * stride_output_token + cur_off * stride_output_hidden tl.store(output_ptr + output_off, inp, mask=mask) + if PERMUTE_SCALE: + permuted_scale_off = ( + dst_row * stride_permuted_scale_token + + cur_off * stride_permuted_scale_hidden + ) + tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) if PERMUTE_PROBS: if cur_pos == 0: prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert @@ -173,10 +189,12 @@ def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor, + scale: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, hidden_size: int, + scale_hidden_dim: int, ): # pylint: disable=missing-function-docstring output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") @@ -184,26 +202,42 @@ def permute_with_mask_map( permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") else: permuted_probs = None + + if scale is not None: + permuted_scale = torch.empty( + (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda" + ) + else: + permuted_scale = None + grid = (num_tokens,) _permute_kernel[grid]( inp, output, row_id_map, probs, + scale, permuted_probs, + permuted_scale, num_tokens, num_experts, hidden_size, + scale_hidden_dim, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), probs.stride(0) if probs is not None else None, probs.stride(1) if probs is not None else None, + scale.stride(0) if scale is not None else None, + scale.stride(1) if scale is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None, + permuted_scale.stride(0) if permuted_scale is not None else None, + permuted_scale.stride(1) if permuted_scale is not None else None, PERMUTE_PROBS=probs is not None, + PERMUTE_SCALE=scale is not None, ) - return output, permuted_probs + return output, permuted_scale, permuted_probs @triton.jit @@ -232,18 +266,9 @@ def _unpermute_kernel( # metas WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, - FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - if FP8_DTYPE == "e5m2": - data_type = tl.float8e5 - pytorch_tensor_dtype = tl.uint8 - elif FP8_DTYPE == "e4m3": - data_type = tl.float8e4nv - pytorch_tensor_dtype = tl.uint8 - else: - data_type = input_ptr.dtype.element_ty - assert FP8_DTYPE is None + data_type = input_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) @@ -257,8 +282,6 @@ def _unpermute_kernel( if src_row != -1: input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) - if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True) inp = inp.to(compute_type) if WITH_MERGING_PROBS: merging_prob_off = ( @@ -279,14 +302,7 @@ def _unpermute_kernel( tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) else: tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) - if FP8_DTYPE is not None: - if not WITH_MERGING_PROBS: - # Directly adding these value may cause overflow for fp8, we scale it here. - # The outside fp8_scale_inv is also scaled in the meantime. - accumulator /= num_experts - accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) - else: - accumulator = accumulator.to(data_type) + accumulator = accumulator.to(data_type) output_off = pid * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) current_start += BLOCK_SIZE @@ -315,15 +331,8 @@ def unpermute_with_mask_map( num_tokens: int, num_experts: int, hidden_size: int, - fp8_dtype: TE_DType, ): # pylint: disable=missing-function-docstring - if fp8_dtype == TE_DType.kFloat8E5M2: - fp8_dtype = "e5m2" - elif fp8_dtype == TE_DType.kFloat8E4M3: - fp8_dtype = "e4m3" - else: - fp8_dtype = None output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") if permuted_probs is not None: unpermuted_probs = torch.empty( @@ -353,7 +362,6 @@ def unpermute_with_mask_map( unpermuted_probs.stride(1) if unpermuted_probs is not None else None, WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, - FP8_DTYPE=fp8_dtype, ) return output, unpermuted_probs @@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel( stride_merging_probs_grad_token, stride_merging_probs_grad_expert, # metas - FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - if FP8_DTYPE == "e5m2": - data_type = tl.float8e5 - pytorch_tensor_dtype = tl.uint8 - elif FP8_DTYPE == "e4m3": - data_type = tl.float8e4nv - pytorch_tensor_dtype = tl.uint8 - else: - data_type = fwd_output_grad_ptr.dtype.element_ty - assert FP8_DTYPE is None + data_type = fwd_output_grad_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) @@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel( + current_offset * stride_fwd_output_grad_hidden ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True) inp = inp.to(compute_type) merging_prob_off = ( pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert @@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel( merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) output = inp * merging_prob output = output.to(data_type) - if FP8_DTYPE is not None: - output = output.to(pytorch_tensor_dtype, bitcast=True) output_off = ( dst_row * stride_fwd_input_grad_token + current_offset * stride_fwd_input_grad_hidden @@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel( dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden ) fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - if FP8_DTYPE is not None: - fwd_input = fwd_input.to(data_type, bitcast=True) prob_grad_accum += fwd_input.to(compute_type) * inp current_start += BLOCK_SIZE probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) @@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_experts: int, num_out_tokens: int, hidden_size: int, - fp8_dtype: TE_DType, ): # pylint: disable=missing-function-docstring - if fp8_dtype == TE_DType.kFloat8E5M2: - fp8_dtype = "e5m2" - elif fp8_dtype == TE_DType.kFloat8E4M3: - fp8_dtype = "e4m3" - else: - fp8_dtype = None act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) @@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs.stride(1), merging_probs_grad.stride(0), merging_probs_grad.stride(1), - fp8_dtype, ) return act_grad, merging_probs_grad From 38e18f7fd7d50e0dfa1f78c2b11d3ac29ecd0819 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 14 Apr 2025 05:01:47 -0700 Subject: [PATCH 24/53] fix unittest Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/_common.py | 23 ++++++++----------- .../pytorch/module/grouped_linear.py | 10 ++++---- .../pytorch/module/layernorm_linear.py | 8 +++---- .../pytorch/module/layernorm_mlp.py | 8 +++---- transformer_engine/pytorch/module/linear.py | 8 +++---- 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index e1c98c4810..9f82b4554b 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -9,8 +9,8 @@ from functools import reduce from operator import mul as multiply_op -import torch import queue +import torch from .. import cpp_extensions as tex from ..constants import TE_DType @@ -226,7 +226,7 @@ class WeightGradStore: """ def __init__( - self, split_bw=False, use_bias=False, fuse_wgrad_accumulation=True, ub_bulk_wgrad=False + self, split_bw=False, ub_bulk_wgrad=False ): """ Initialize the WeightGradStore. @@ -236,7 +236,7 @@ def __init__( """ if split_bw: self.context = queue.Queue() - assert ub_bulk_wgrad == False, "ub_bulk_wgrad is not supported when enabling split_bw" + assert ub_bulk_wgrad is False, "ub_bulk_wgrad is not supported when enabling split_bw" self.enabled = split_bw else: self.context = None @@ -267,31 +267,28 @@ def put(self, tensor_list, func): tensor_list (list): List of tensors needed for computation func (callable): Function to be executed with the tensors """ - assert self.enabled == True, "split_bw is not enabled" + assert self.enabled is True, "split_bw is not enabled" self.context.put([tensor_list, func]) - return def pop(self): """ Execute the stored computation with the stored tensors. Raises an exception if the queue is empty. """ - assert self.enabled == True, "split_bw is not enabled" + assert self.enabled is True, "split_bw is not enabled" if self.context.qsize() > 0: tensor_list, func = self.context.get() return func(*tensor_list), tensor_list - else: - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - raise Exception(f"Pop empty queue. rank {rank}") - else: - raise Exception("Pop empty queue. No distributed environment detected.") + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + raise Exception(f"Pop empty queue. rank {rank}") + raise Exception("Pop empty queue. No distributed environment detected.") def assert_empty(self): """ Assert that the queue is empty. Used for debugging and ensuring proper cleanup. """ - assert self.enabled == True, "split_bw is not enabled" + assert self.enabled is True, "split_bw is not enabled" rank = torch.distributed.get_rank() assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 473e19b8aa..52741789cf 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,8 +5,8 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List -import torch import functools +import torch import transformer_engine_torch as tex @@ -313,7 +313,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, ) # WGRAD - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) @@ -360,10 +360,10 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or (ctx.wgrad_store.split_bw() and not ctx.fp8): + if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw() and not ctx.fp8): grad_biases = [None] * ctx.num_gemms if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): @@ -700,7 +700,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7a2857670c..da7007f0dc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -9,9 +9,9 @@ from functools import reduce from operator import mul as multiply_op +import functools import torch from torch.nn import init -import functools import transformer_engine_torch as tex @@ -759,7 +759,7 @@ def backward( bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad) else: wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output) @@ -781,7 +781,7 @@ def backward( dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) # Don't return grad bias if not needed - if not ctx.use_bias or ctx.wgrad_store.split_bw(): + if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw()): grad_bias = None # Synchronize tensor parallel communication @@ -1498,7 +1498,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): super().backward_dw() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7f9ac1aa1b..6e9e6181c5 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -801,7 +801,7 @@ def backward( use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) fc2_wgrad = None fc2_bias_grad = None @@ -822,7 +822,7 @@ def backward( fc2_bias_grad = fc2_bias_grad_ del fc2_bias_grad_ - if not ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and not ctx.wgrad_store.split_bw(): clear_tensor_data(act_out) # bias computation @@ -1002,7 +1002,7 @@ def backward( extra_output=fc1_dgrad_rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) fc1_wgrad = None # (fc1_wgrad_outputs), _ = ctx.wgrad_store.pop() @@ -1738,7 +1738,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 77a56ef83a..5fc9217487 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,8 +7,8 @@ from functools import reduce from operator import mul as multiply_op -import torch import functools +import torch import transformer_engine_torch as tex @@ -679,7 +679,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad) else: wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output) @@ -700,7 +700,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) # Don't return grad bias if not needed - if not ctx.use_bias or ctx.wgrad_store.split_bw(): + if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw()): grad_bias = None # Make sure all tensor-parallel communication is finished @@ -1304,7 +1304,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.split_bw(): return with torch.cuda.nvtx.range("_Linear_wgrad"): super().backward_dw() From 5f16c7924536ab945781c3b32d5a50f19a0650bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:02:37 +0000 Subject: [PATCH 25/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/_common.py | 4 +--- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 9f82b4554b..bbe22370b2 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -225,9 +225,7 @@ class WeightGradStore: This class enables split backward propagation for better memory efficiency. """ - def __init__( - self, split_bw=False, ub_bulk_wgrad=False - ): + def __init__(self, split_bw=False, ub_bulk_wgrad=False): """ Initialize the WeightGradStore. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 52741789cf..2dc4d1dd38 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -363,7 +363,9 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw() and not ctx.fp8): + if not ctx.use_bias or ( + ctx.wgrad_store is not None and ctx.wgrad_store.split_bw() and not ctx.fp8 + ): grad_biases = [None] * ctx.num_gemms if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): From 98b4c0d90e68dd84f8ac4d00a4539d097fa618b4 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Mon, 14 Apr 2025 08:16:19 -0700 Subject: [PATCH 26/53] [JAX] grouped_gemm() uses variadic arguments (#1658) * New GroupedGemmPrimitive using variadic args * Remove squeeze() to reduce D2D memcpy * Revert to the list append fashion to simplify code --------- Signed-off-by: Hua Huang Co-authored-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 215 +++++++----------- .../jax/csrc/extensions/gemm.cpp | 197 ++++++++-------- 2 files changed, 173 insertions(+), 239 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0327542c2f..588e7a469d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9) + impl_static_args = () inner_primitive = None outer_primitive = None @staticmethod - def abstract( - lhs_contig_aval, - lhs_scale_contig_aval, - rhs_contig_aval, - rhs_scale_contig_aval, - bias_contig_aval, - dim_list_aval, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ): - del lhs_contig_aval, lhs_scale_contig_aval - del rhs_contig_aval, rhs_scale_contig_aval - del bias_contig_aval, dim_list_aval - del num_gemms, scaling_mode - out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype) - wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8) - return (out_flat_aval, wkspace_aval) + def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + """ + Args: + *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: + args[ 0 : num_gemms] are the lhs tensors, + args[ num_gemms : 2*num_gemms] are the rhs tensors, + args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, + args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, + args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. + num_gemms: Number of GEMM operations to perform. + scaling_mode: Scaling mode for the GEMM operations. + out_dtype: Data type of the output tensors. + has_bias: Boolean indicating if bias tensors are provided. + + Returns: + A tuple of ShapedArray objects of size num_gemms+1: + ret[0 : num_gemms]: GEMM output tensors, + ret[num_gemms]:workspace tensor. + """ + del scaling_mode + expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms + assert ( + len(args) == expected_num_args + ), f"Expected {expected_num_args} input arguments, but got {len(args)}" + A_list = args[0:num_gemms] + B_list = args[num_gemms : 2 * num_gemms] + # A and B have shapes [1, m, k] and [1, n, k] + out_list_aval = tuple( + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + for A, B in zip(A_list, B_list) + ) + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + return (*out_list_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): @@ -74,60 +87,27 @@ def outer_abstract(*args, **kwargs): return out_aval @staticmethod - def lowering( - ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: - del out_dtype, out_flat_size + def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=int(scaling_mode), + has_bias=has_bias, ) @staticmethod - def impl( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: + def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): assert GroupedGemmPrimitive.inner_primitive is not None out = GroupedGemmPrimitive.inner_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=has_bias, ) - return out[0] # out is [out_flat, wkspace], only return out_flat + return out[:-1] # out is [out_list, wkspace], only return out_list register_primitive(GroupedGemmPrimitive) @@ -366,6 +346,7 @@ def swizzled_scale(scales): rows, cols = scales.shape scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) + scales = scales.reshape(rows, cols) return scales @@ -380,18 +361,12 @@ def grouped_gemm( len(lhs_list) == len(rhs_list) == len(contracting_dims_list) ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - # Flatten inputs and save their shapes - num_gemms = len(lhs_list) - out_flat_size = 0 - dims = [] - lhs_contig_ = [] - rhs_contig_ = [] - lhs_scale_inv_contig_ = [] - rhs_scale_inv_contig_ = [] - bias_contig_ = [] - out_offsets = [] - remain_shape_list = [] num_gemms = len(lhs_list) + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] for i in range(num_gemms): lhs = lhs_list[i] rhs = rhs_list[i] @@ -402,7 +377,7 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout + # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 @@ -427,6 +402,7 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) + # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) @@ -438,13 +414,13 @@ def grouped_gemm( rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + # swizzled_scale requires a matrix lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) else: raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - # Note: if _shape_normalization() is updated to support non-TN, need to update here - # already_transposed doesn't matter for the output shape + # Note: already_transposed doesn't matter for the output shape # x.shape = [B, D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] @@ -455,66 +431,37 @@ def grouped_gemm( bn = rhs_remain_shape[0] kl = lhs_3d.shape[-1] kr = rhs_3d.shape[-1] - remain_shape_list.append(((bm,), (bn,))) - assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" - k = kl - - if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): - print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print( - f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" - " of 16" - ) - assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 - - dims.append((bm, bn, k)) - lhs_contig_.append(lhs_3d.reshape(-1)) - rhs_contig_.append(rhs_3d.reshape(-1)) + assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" + if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): + print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print(f"m = {bm}, n = {bn}, k = {kl}; ") + print("cuBLAS requires the problem shapes being multiples of 16") + assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) + + lhs_list_.append(lhs_3d) + rhs_list_.append(rhs_3d) if scaling_mode == ScalingMode.NO_SCALING: - lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) if bias_list is not None: - bias_contig_.append(bias_list[i].reshape(-1)) - out_flat_size += bm * bn - out_offsets.append(out_flat_size) - - lhs_contig = jnp.concatenate(lhs_contig_) - rhs_contig = jnp.concatenate(rhs_contig_) - lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_) - rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_) - bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_) - dim_list = jnp.array(dims, dtype=jnp.int32) - - # TE/common does not support NVTE_NO_SCALING yet - # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NO_SCALING: - scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING - - # Perform batched GEMM on flattened inputs - out_contig = GroupedGemmPrimitive.outer_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + bias_list_.append(bias_list[i]) + + out_list = GroupedGemmPrimitive.outer_primitive.bind( + *lhs_list_, + *rhs_list_, + *lhs_sinv_list_, + *rhs_sinv_list_, + *bias_list_, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=1 if bias_list is not None else 0, ) - # Split the output back into tensors - out_offsets = jnp.array(out_offsets) - out_flat_list = jnp.split(out_contig, out_offsets[:-1]) - out_tensors = [] - for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list): - out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) - - return out_tensors + return out_list diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d4b9bf720e..4318e19c75 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,34 +15,9 @@ namespace transformer_engine { namespace jax { -constexpr static size_t MXFP8_BLOCK_SIZE = 32; - -// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX) -Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr, - const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype, - uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, - const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, - uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, - cudaStream_t stream) { - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms; - std::unique_ptr dim_list_host = std::make_unique(3 * num_gemms); - - cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - +Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -56,6 +31,18 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + if (num_gemms <= 0) { + return ffi_with_cuda_error_check(); + } + size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; + size_t expected_output_size = num_gemms + 1; + size_t actual_input_size = input_list.size(); + size_t actual_output_size = output_list.size(); + NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", + expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", + expected_output_size, actual_output_size); + bool trans_lhs = true; bool trans_rhs = false; auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh std::vector out_list; std::vector workspace_list; + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; + int lhs_sinv_list_offset = 2 * num_gemms; + int rhs_sinv_list_offset = 3 * num_gemms; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; for (int i = 0; i < num_gemms; i++) { - size_t m = dim_list_host[i * 3]; - size_t n = dim_list_host[i * 3 + 1]; - size_t k = dim_list_host[i * 3 + 2]; + Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); + Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); + + DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); + DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); + DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); + + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); + void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); + void *out_ptr = out_i->untyped_data(); + + // Placeholder for bias since it can be empty + DType bias_dtype = DType::kFloat32; + void *bias_ptr = nullptr; + + auto lhs_shape_ = lhs_i.dimensions(); + auto rhs_shape_ = rhs_i.dimensions(); + + // lhs and rhs has shape [1, m, k] and [1, n, k] + size_t m = lhs_shape_[1]; + size_t n = rhs_shape_[1]; + size_t k = lhs_shape_[2]; auto lhs_shape = std::vector{m, k}; auto rhs_shape = std::vector{n, k}; @@ -90,52 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); - - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, - std::vector{1}); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, - std::vector{1}); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + float *amax_dptr = nullptr; + float *scale_dptr = nullptr; + auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(lhs_sinv_ptr)); + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(rhs_sinv_ptr)); + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t sinv_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_shape[0] = m; - lhs_sinv_shape[1] = sinv_k; - rhs_sinv_shape[0] = n; - rhs_sinv_shape[1] = sinv_k; - // Note: the scale_inv array should have been swizzled in Python before lowering - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, - lhs_sinv_shape); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, - rhs_sinv_shape); + auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); + auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); + for (int i = 0; i < 2; i++) { + lhs_sinv_shape[i] = lhs_sinv_shape_[i]; + rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + } + + NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + TensorWrapper lhs_i_(nvte_scaling_mode); + TensorWrapper rhs_i_(nvte_scaling_mode); + lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); + rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); + lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); + rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else { NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); - lhs_ptr += m * k * lhs_dtype_bytes; - rhs_ptr += n * k * rhs_dtype_bytes; - out_ptr += m * n * out_dtype_bytes; - lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes; + auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); void *pre_gelu_ptr = nullptr; auto bias_shape = std::vector{0}; auto pre_gelu_shape = std::vector{0}; - if (bias_ptr != nullptr) bias_shape[0] = n; + if (has_bias) { + auto bias_i_get = input_list.get(bias_list_offset + i); + Buffer_Type bias_i = bias_i_get.value(); + bias_ptr = bias_i.untyped_data(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); + bias_shape[0] = n; + } auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes; auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); - out_wrapper_list.push_back(std::move(out_i)); + out_wrapper_list.push_back(std::move(out_i_)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -146,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh out_list.push_back(out_wrapper_list.back().data()); } + auto workspace_get = output_list.get(num_gemms); + Result_Type workspace = workspace_get.value(); + uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -163,50 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh return ffi_with_cuda_error_check(); } -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, - Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, - Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, - Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode) { - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv_flatten.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv_flatten.untyped_data()); - auto bias_ptr = reinterpret_cast(bias_flatten.untyped_data()); - auto dim_list_ptr = reinterpret_cast(dim_list.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type()); - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type()); - - // Outputs - auto out_ptr = reinterpret_cast(out_flatten->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type()); - auto workspace_ptr = reinterpret_cast(workspace_flatten->untyped_data()); - auto workspace_size = workspace_flatten->dimensions().back() / num_streams; - - return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype, - rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype, - workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode, - stream); -} - XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_flatten - .Arg() // lhs_sinv_flatten - .Arg() // rhs_flatten - .Arg() // rhs_sinv_flatten - .Arg() // bias_flatten - .Arg() // dim_list - .Ret() // out_flatten - .Ret() // workspace_flatten + .RemainingArgs() // input list + .RemainingRets() // output list .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode") + .Attr("has_bias"), FFI_CudaGraph_Traits); } // namespace jax From 6117b20caa632a9e358710ac50b233632ee0df77 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 14 Apr 2025 19:20:54 +0200 Subject: [PATCH 27/53] Add experimental Shardy support. (#1642) * Add experimental Shardy support. Production use is not yet recommended. --------- Signed-off-by: Johannes Reifferscheid --- .../run_test_multiprocessing_encoder.sh | 12 + .../encoder/test_model_parallel_encoder.py | 38 ++- examples/jax/encoder/test_multigpu_encoder.py | 27 +- .../encoder/test_multiprocessing_encoder.py | 23 +- .../jax/encoder/test_single_gpu_encoder.py | 7 +- tests/jax/pytest.ini | 2 + tests/jax/test_distributed_fused_attn.py | 285 +++++++++++++----- tests/jax/test_distributed_layernorm.py | 6 + tests/jax/test_distributed_layernorm_mlp.py | 79 ++++- tests/jax/test_distributed_softmax.py | 96 +++++- .../jax/cpp_extensions/activation.py | 89 ++++++ .../jax/cpp_extensions/attention.py | 44 ++- transformer_engine/jax/cpp_extensions/base.py | 13 +- .../jax/cpp_extensions/normalization.py | 57 ++++ .../jax/cpp_extensions/quantization.py | 43 +++ .../jax/cpp_extensions/softmax.py | 30 ++ .../jax/quantize/scaling_modes.py | 107 ++++++- 17 files changed, 849 insertions(+), 109 deletions(-) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index c14a462f75..ff38c7e335 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -21,3 +21,15 @@ do pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i & done wait + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16_shardy --num-process=$NUM_GPUS --process-id=$i & +done +wait + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8_shardy --num-process=$NUM_GPUS --process-id=$i & +done +wait diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index eabd1b2a3f..0577787b7c 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -258,6 +258,8 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -441,6 +443,9 @@ def encoder_parser(args): parser.add_argument( "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -451,10 +456,9 @@ class TestEncoder(unittest.TestCase): is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): + def setUp(self): """Run 3 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -503,6 +507,34 @@ def test_te_mxfp8_with_sp(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.455 and actual[1] > 0.785 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + self.args.enable_shardy = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_with_sp_shardy(self): + """Test Transformer Engine with DelayedScaling FP8 + SP""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.455 and actual[1] > 0.785 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 839bc3175e..c196692757 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -238,6 +238,7 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -409,6 +410,9 @@ def encoder_parser(args): default="DelayedScaling", help="Use FP8 recipe (default: DelayedScaling)", ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -419,10 +423,9 @@ class TestEncoder(unittest.TestCase): is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): + def setUp(self): """Run 3 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -446,6 +449,24 @@ def test_te_mxfp8(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.535 and actual[1] > 0.73 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + self.args.enable_shardy = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.535 and actual[1] > 0.73 + + @unittest.skipIf(not is_fp8_supported, fp8_reason) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "DelayedScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.535 and actual[1] > 0.73 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index a2b160b522..56d386a3f5 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -343,6 +343,7 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) if args.process_id == 0: nltk.download("punkt_tab") @@ -565,6 +566,9 @@ def encoder_parser(args): default=0, help="the ID number of the current process (default: 0)", ) + parser.add_argument( + "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." + ) return parser.parse_args(args) @@ -573,7 +577,7 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - def exec(self, use_fp8, fp8_recipe): + def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): """Run 3 epochs for testing""" args = encoder_parser([]) @@ -589,6 +593,7 @@ def exec(self, use_fp8, fp8_recipe): args.num_process = num_gpu args.process_id = self.process_id args.fp8_recipe = fp8_recipe + args.enable_shardy = enable_shardy return train_and_evaluate(args) @@ -614,6 +619,22 @@ def test_te_mxfp8(self): result = self.exec(True, "MXFP8BlockScaling") assert result[0] < 0.505 and result[1] > 0.754 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + def test_te_bf16_shardy(self): + """Test Transformer Engine with BF16""" + result = self.exec(False, None, enable_shardy=True) + assert result[0] < 0.505 and result[1] > 0.755 + + @unittest.skipIf( + not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" + ) + def test_te_delayed_scaling_fp8_shardy(self): + """Test Transformer Engine with DelayedScaling FP8""" + result = self.exec(True, "DelayedScaling", enable_shardy=True) + assert result[0] < 0.505 and result[1] > 0.754 + + # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index df78157cc5..1783ca8177 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -330,10 +330,9 @@ class TestEncoder(unittest.TestCase): is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - @classmethod - def setUpClass(cls): - """Run 4 epochs for testing""" - cls.args = encoder_parser(["--epochs", "3"]) + def setUp(self): + """Run 3 epochs for testing""" + self.args = encoder_parser(["--epochs", "3"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini index 1e835b2187..70d4188c5f 100644 --- a/tests/jax/pytest.ini +++ b/tests/jax/pytest.ini @@ -25,3 +25,5 @@ filterwarnings= ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning ignore:Scan loop is disabled for fused ring attention.*:UserWarning + ignore:jax.extend.ffi.register_ffi_target is deprecated + ignore:jax.extend.ffi.ffi_lowering is deprecated diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index bb7f83b319..ecca5ab322 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -48,31 +48,7 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "data_shape", - [ - pytest.param((32, 512, 12, 64), id="32-512-12-64"), - pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), - ], - ) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), - ], - ) - @pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), - ], - ) - @pytest.mark.parametrize("dtype", DTYPES) - def test_self_attn( + def impl_test_self_attn( self, device_count, mesh_shape, @@ -83,7 +59,9 @@ def test_self_attn( bias_shape, attn_mask_type, dtype, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True @@ -137,6 +115,80 @@ def test_self_attn( ) runner.test_backward() + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize( + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], + ) + @pytest.mark.parametrize("dtype", DTYPES) + def test_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + attn_mask_type, + dtype, + ): + self.impl_test_self_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + attn_mask_type, + dtype, + use_shardy=False, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + ], + ) + def test_self_attn_shardy( + self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape + ): + data_shape = (32, 512, 12, 64) + self.impl_test_self_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + bias_shape, + AttnMaskType.PADDING_MASK, + jnp.bfloat16, + use_shardy=True, + ) + class TestDistributedCrossAttn: @@ -203,37 +255,23 @@ def test_cross_attn( runner.test_backward() -@pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() -) -@pytest.mark.parametrize( - "data_shape", - [ - # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. - pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), - pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), - ], -) -@pytest.mark.parametrize("kv_groups", [1, 8]) -@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) -@pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - [ - pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), - pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), - pytest.param( - QKVLayout.THD_THD_THD, - AttnMaskType.PADDING_CAUSAL_MASK, - id="THD_SEPARATE-PADDING_CAUSAL", - ), - ], -) -@pytest.mark.parametrize( - "load_balanced", - [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], -) +DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + pytest.param( + QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL" + ), +] + +DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), +] + + class TestDistributedContextParallelSelfAttn: def impl_test_context_parallel_attn( @@ -249,7 +287,23 @@ def impl_test_context_parallel_attn( qkv_layout, load_balanced, cp_strategy, + use_shardy, + use_scan_ring=False, ): + if qkv_layout.is_thd(): + if cp_strategy == CPStrategy.ALL_GATHER: + pytest.skip("THD doesn't support all gather context parallelism.") + if not load_balanced and cp_strategy == CPStrategy.RING: + pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + + assert not use_scan_ring or cp_strategy == CPStrategy.RING + + if use_scan_ring: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" + else: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" + + jax.config.update("jax_use_shardy_partitioner", use_shardy) attn_bias_type = AttnBiasType.NO_BIAS bias_shape = None dropout_prob = 0.0 @@ -324,7 +378,58 @@ def check_has_backend_for_mask(mask_type): pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") runner.test_backward() + del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + def test_context_parallel_allgather_attn_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_mask_type, + dtype, + qkv_layout, + ): + kv_groups = 8 + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced=True, + cp_strategy=CPStrategy.ALL_GATHER, + use_shardy=True, + ) + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], + ) def test_context_parallel_allgather_attn( self, device_count, @@ -338,9 +443,7 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, ): - if qkv_layout.is_thd(): - pytest.skip("THD doesn't support all gather context parallelism.") - return self.impl_test_context_parallel_attn( + self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -352,8 +455,23 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, + use_shardy=False, ) + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], + ) @pytest.mark.parametrize( "use_scan", [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], @@ -372,14 +490,6 @@ def test_context_parallel_ring_attn( load_balanced, use_scan, ): - if use_scan: - os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" - else: - os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - - if qkv_layout.is_thd() and not load_balanced: - pytest.skip("THD + ring doesn't support unbalanced context parallelism.") - self.impl_test_context_parallel_attn( device_count, mesh_shape, @@ -392,9 +502,46 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, CPStrategy.RING, + use_shardy=False, + use_scan_ring=use_scan, + ) + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + ) + def test_context_parallel_ring_attn_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_mask_type, + dtype, + qkv_layout, + ): + kv_groups = 8 + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced=True, + cp_strategy=CPStrategy.RING, + use_shardy=False, + use_scan_ring=True, ) - del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - return class TestReorderCausalLoadBalancing: diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 476d455a6a..0358a2a2e3 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -86,6 +86,7 @@ def generate_collectives_count_ref( @pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_layernorm( self, device_count, @@ -97,7 +98,9 @@ def test_layernorm( zero_centered_gamma, shard_weights, fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "layernorm" q_dtype = jnp.float8_e4m3fn @@ -168,6 +171,7 @@ def ref_func(x, gamma, beta): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_rmsnorm( self, device_count, @@ -178,7 +182,9 @@ def test_rmsnorm( dtype, shard_weights, fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "rmsnorm" q_dtype = jnp.float8_e4m3fn diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index cf311ac404..f97f264245 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -144,16 +144,10 @@ def layernorm_fp8_mlp_prim_func( ) ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + def _test_layernorm_mlp_grad( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -257,9 +251,59 @@ def test_layernorm_mlp_grad( err_msg=f"multi_grads[{i}] is not close", ) + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + def test_layernorm_mlp_grad( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + self._test_layernorm_mlp_grad( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy=False, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + def test_layernorm_mlp_grad_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + # We don't test block scaling with Shardy because at the time of writing, + # it is not supported in JAX's scaled_matmul_stablehlo. + self._test_layernorm_mlp_grad( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe=recipe.DelayedScaling(), + use_shardy=True, + ) + def _test_layernorm_mlp( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8, + fp8_recipe, + use_shardy, ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -322,9 +366,19 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): + @pytest_parametrize_wrapper("use_shardy", [False, True]) + def test_layernorm_mlp_layer( + self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + ): self._test_layernorm_mlp( - mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=use_shardy, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -345,4 +399,5 @@ def test_layernorm_mlp_layer_fp8( dtype, use_fp8=True, fp8_recipe=fp8_recipe, + use_shardy=False, ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 30a9fd53ea..cb30c34abc 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -28,14 +28,16 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): + def generate_inputs( + self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask + ): batch, _, sqelen, _ = shape x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED: mask = make_causal_mask(batch, sqelen) else: - mask = make_self_mask(batch, sqelen) + mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen) if not bad_sharding: x_pspec = PartitionSpec( @@ -45,7 +47,11 @@ def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_shardin x_pspec = PartitionSpec( mesh_resource.dp_resource, None, None, mesh_resource.tp_resource ) - mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) + + if broadcast_batch_mask: + mask_pspec = PartitionSpec(None, None, None, None) + else: + mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) return (x, mask), (x_pspec, mask_pspec) @@ -67,16 +73,7 @@ def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16): output = jax.nn.softmax(x * scale_factor) return jnp.mean(output) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) - @pytest.mark.parametrize( - "softmax_type", - [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], - ) - @pytest.mark.parametrize("scale_factor", [1.0, 3.0]) - @pytest.mark.parametrize("dtype", DTYPES) - @pytest.mark.parametrize("bad_sharding", [False, True]) - def test_softmax( + def impl_test_softmax( self, device_count, mesh_shape, @@ -87,15 +84,20 @@ def test_softmax( scale_factor, dtype, bad_sharding, + broadcast_batch_mask, + use_shardy, ): + if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED: + pytest.skip("Softmax type has no mask.") + jax.config.update("jax_use_shardy_partitioner", use_shardy) target_func = partial( self.target_func, scale_factor=scale_factor, softmax_type=softmax_type ) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) (x, mask), (x_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, softmax_type, dtype, bad_sharding + data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask ) collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) @@ -129,4 +131,70 @@ def test_softmax( assert "Sharding the hidden dimension is not supported" in str(w), ( "Softmax primitive did not raise the correct warning for " "unsupported sharding in the hidden dimension." + f"{str(w)}" ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize( + "softmax_type", + [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], + ) + @pytest.mark.parametrize("scale_factor", [1.0, 3.0]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("bad_sharding", [False, True]) + @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) + def test_softmax( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + softmax_type, + scale_factor, + dtype, + bad_sharding, + broadcast_batch_mask, + ): + self.impl_test_softmax( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + softmax_type, + scale_factor, + dtype, + bad_sharding, + broadcast_batch_mask, + use_shardy=False, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) + @pytest.mark.parametrize("bad_sharding", [False, True]) + @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) + def test_softmax_shardy( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + softmax_type, + bad_sharding, + broadcast_batch_mask, + ): + self.impl_test_softmax( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape=[32, 12, 128, 128], + softmax_type=softmax_type, + scale_factor=1.0, + dtype=DTYPES[0], + bad_sharding=bad_sharding, + broadcast_batch_mask=broadcast_batch_mask, + use_shardy=True, + ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index c27f6f50f7..21d5503e3e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -406,6 +407,54 @@ def sharded_impl(x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + act_enum, + act_len, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types + + x_rank = len(value_types[0].shape) + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + x_rank - 1, unique_var="i", flatten_axis=-2 + ) + x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) + out = (*x_axes[:-2], x_axes[-1]) + scale_inv = scale_rules.rowwise_rule + colwise_scale_inv = scale_rules.colwise_rule + + if is_2x: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple( + multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) + ) + else: + colwise_out = out + else: + colwise_out = ("j",) + colwise_scale_inv = ("k",) + + # amax is always a unit tensor. + amax = ("l",) + + return SdyShardingRule( + ( + x_axes, + "…1", + ), + (out, colwise_out, scale_inv, colwise_scale_inv, amax), + **scale_rules.factor_sizes, + ) + register_primitive(ActLuPrimitive) @@ -819,6 +868,46 @@ def sharded_impl(dz, x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_dbias, + act_enum, + act_len, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types + + x_rank = len(value_types[1].shape) + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + x_rank, unique_var="i", flatten_axis=-2 + ) + x_axes = scale_rules.input_spec + out = x_axes + if is_2x: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) + else: + colwise_out = tuple(x_axes) + else: + colwise_out = ("j",) + + dbias = x_axes[-2:] if is_dbias else ("k",) + amax = ("…4",) + + return SdyShardingRule( + (("…0",), tuple(x_axes), ("…2",)), + (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + **scale_rules.factor_sizes, + ) + register_primitive(DActLuDBiasQuantizePrimitive) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 7a31fa729d..ea682d4c47 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -14,6 +14,7 @@ import jax.numpy as jnp from jax import dtypes, lax from jax.sharding import PartitionSpec, NamedSharding +from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -42,6 +43,7 @@ get_mesh_axis_rank, get_all_mesh_axes, num_of_devices, + with_sharding_constraint, ) @@ -618,6 +620,35 @@ def partition(config, mesh, arg_infos, result_infos): impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + del mesh, result_types + + # Keep in sync with `infer_sharding_from_operands`. + # We only need the first input. Fill up the rest with placeholders. + input_spec = [(f"…{x}",) for x in range(len(value_types))] + # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint + # instead. This has to happen outside of the primitive, see `fused_attn_fwd`. + rng_sharding = (f"…{len(value_types)}",) + + if config.qkv_layout.is_qkvpacked(): + input_spec[0] = ("…0", "seqlen", "three", "head", "hidden") + elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate(): + input_spec[0] = ("…0", "seqlen", "head", "hidden") + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") + + is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() + out_sharding = ("…0", "seqlen", "head", "hidden") + if is_packed_softmax: + softmax_aux_sharding = ("…0", "seqlen", "head", "i") + else: + softmax_aux_sharding = ("…0", "head", "seqlen", "i") + + return SdyShardingRule( + tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding) + ) + register_primitive(FusedAttnFwdPrimitive) @@ -998,6 +1029,15 @@ def sharded_impl( return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + del config, mesh + # We only care about the four first arguments. + # Keep in sync with `infer_sharding_from_operands`. + input_spec = tuple((f"…{x}",) for x in range(len(value_types))) + output_spec = tuple((f"…{x}",) for x in range(len(result_types))) + return SdyShardingRule(input_spec, output_spec) + register_primitive(FusedAttnBwdPrimitive) @@ -2436,13 +2476,15 @@ def fused_attn_fwd( primitive = FusedRingAttnFwdPrimitive.outer_primitive seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) - return primitive.bind( + output, softmax_aux, rng_state = primitive.bind( *qkv_for_primitive, bias, seed, *seq_desc_flatten, config=fused_config, ) + rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None)) + return (output, softmax_aux, rng_state) def fused_attn_bwd( diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 5d64fa9bb6..1c9bade0e7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -98,6 +98,15 @@ def partition(): """ return NotImplemented + @staticmethod + @abstractmethod + def shardy_sharding_rule(*args): + """ + Returns the sharding rule for this primitive. + """ + del args + return "... -> ..." + def register_primitive(cls): """ @@ -123,7 +132,9 @@ def name_of_wrapper_p(): batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition + infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition, + sharding_rule=cls.shardy_sharding_rule, ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 388d4f17ee..3aec30f420 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -12,6 +12,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -519,6 +520,57 @@ def sharded_impl(x, scale, gamma, beta): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + norm_type, + zero_centered_gamma, + epsilon, + out_dtype, + scaling_mode, + is_2x, + scale_dtype, + scale_shapes, + is_outer, + mesh, + value_types, + result_types, + ): + del ( + zero_centered_gamma, + epsilon, + out_dtype, + scale_dtype, + scale_shapes, + is_outer, + mesh, + result_types, + ) + + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + len(value_types[0].shape), unique_var="i", flatten_axis=-1 + ) + x_axes = scale_rules.input_spec + + out = x_axes[:-1] + ("k",) + colwise_out = out if is_2x else ("…4",) + rsigma = x_axes[:-1] + mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = ("…6",) + + return SdyShardingRule( + (x_axes, ("…1",), ("…2",), ("…3",)), + ( + out, + colwise_out, + scale_rules.rowwise_rule, + scale_rules.colwise_rule, + amax, + mu, + rsigma, + ), + **scale_rules.factor_sizes, + ) + register_primitive(NormFwdPrimitive) @@ -722,6 +774,11 @@ def sharded_impl(dz, x, mu, rsigma, gamma): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule(*args): + del args + return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l" + register_primitive(NormBwdPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 2911b5a420..23d8572994 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -470,6 +471,48 @@ def sharded_impl(x, scale): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + scale_shapes, + is_dbias, + is_outer, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types + + scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( + len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis + ) + + x_axes = scale_rules.input_spec + colwise_scale_inv = scale_rules.colwise_rule + + out = x_axes + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) + else: + colwise_out = x_axes + else: + colwise_out = ("j",) + colwise_scale_inv = ("k",) + + dbias = x_axes[flatten_axis:] if is_dbias else ("l",) + amax = ("m",) + + return SdyShardingRule( + (x_axes, ("…1",)), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, + ) + register_primitive(DBiasQuantizePrimitive) diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index b50e98081d..e16bf76c94 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -330,6 +330,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "... -> ..." + register_primitive(ScaledSoftmaxFwdPrimitive) @@ -400,6 +405,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledSoftmaxBwdPrimitive) @@ -525,6 +535,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "...1, ...2 -> ...1" + register_primitive(ScaledMaskedSoftmaxFwdPrimitive) @@ -596,6 +611,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledMaskedSoftmaxBwdPrimitive) @@ -682,6 +702,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): result_infos, ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "... -> ..." + register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) @@ -761,6 +786,11 @@ def partition(scale_factor, mesh, arg_infos, result_infos): result_infos, ) + @staticmethod + def shardy_sharding_rule(*args): + del args + return "..., ... -> ..." + register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 34f63a994c..2a5b23bdcf 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -16,13 +16,33 @@ from functools import reduce import operator +from jax.experimental.custom_partitioning import CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp from transformer_engine_jax import JAXX_Scaling_Mode -__all__ = ["ScalingMode"] +__all__ = ["QuantizeShardyRules", "ScalingMode"] + + +@dataclass +class QuantizeShardyRules: + """Information necessary to shard scale tensors with Shardy. + + Attributes: + input_spec: Specification for the input axes + rowwise_rule: Sharding rule for the row-wise scale tensor, depends on + the axes in `input_spec` + colwise_rule: Likewise for the column-wise scale tensor. + factor_sizes: For block scaling, contains the block size factor, which is + used in `input_spec`. + """ + + input_spec: Tuple[str] + rowwise_rule: Tuple[str] + colwise_rule: Tuple[str] + factor_sizes: Dict[str, int] class ScalingModeMetadataImpl(ABC): @@ -59,6 +79,21 @@ def get_scale_shape( The shape for scale tensors """ + @abstractmethod + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for delayed scaling mode. @@ -95,6 +130,23 @@ def get_scale_shape( del data_shape, is_colwise return (1,) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"x{i}" for i in range(input_rank)) + return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) + class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for block scaling mode. @@ -217,6 +269,45 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + + Returns: + The Shardy rules for the scaling mode + """ + input_spec = [f"x{i}" for i in range(input_rank)] + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. + rowwise_var = unique_var + colwise_var = f"{unique_var}_" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") + input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + rowwise = input_spec.copy() + rowwise[-1] = rowwise_var + + colwise = input_spec.copy() + colwise[flatten_axis - 1] = colwise_var + + # This implementation needs to be updated for different block dims. + assert self._block_dims == (1, 32) + + return QuantizeShardyRules( + tuple(input_spec), + tuple(rowwise), + tuple(colwise), + {"block_size_rowwise": 32, "block_size_colwise": 32}, + ) + @dataclass(frozen=True) @register_pytree_node_class @@ -290,6 +381,20 @@ def get_scale_shape( """ return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis=-1 + ) -> Tuple[Tuple[str]]: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + + Returns: + The Shardy rules for the scaling mode + """ + return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + def __eq__(self, other): """Compare this scaling mode with another. From 4c9626e7ef7ea5325efdf6511156c651de2e69db Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 15 Apr 2025 01:55:29 +0800 Subject: [PATCH 28/53] [PyTorch][MoE] Enable New Recipes for Grouped Linear (#1525) * Enable MXFP8 and Per-Tensor Current Scaling for Grouped Linear Signed-off-by: Xin Yao * enable float8blockwise Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Xin Yao * remove grouped linear parallel mode test Signed-off-by: Xin Yao * update test Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * internal=False for now Signed-off-by: Xin Yao * remove unused import Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 55 ++---- tests/pytorch/test_sanity.py | 8 +- .../pytorch/cpp_extensions/gemm.py | 15 +- transformer_engine/pytorch/csrc/extensions.h | 8 +- .../pytorch/csrc/extensions/pybind.cpp | 10 +- .../pytorch/csrc/extensions/quantizer.cpp | 1 + .../pytorch/csrc/extensions/transpose.cpp | 43 ++--- .../pytorch/module/fp8_padding.py | 22 ++- .../pytorch/module/fp8_unpadding.py | 20 ++- .../pytorch/module/grouped_linear.py | 162 +++++++++++++----- 10 files changed, 210 insertions(+), 134 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7a930b6cde..1dd7a9fd54 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1470,8 +1470,7 @@ def _test_grouped_linear_accuracy( if num_gemms > 1: split_size = 1 if fp8: - if recipe.delayed(): - split_size = 16 + split_size = 16 if recipe.mxfp8(): split_size = 128 m = config.seq_len // split_size @@ -1509,12 +1508,11 @@ def _test_grouped_linear_accuracy( return outputs -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_grouped_linear_accuracy( @@ -1522,22 +1520,18 @@ def test_grouped_linear_accuracy( num_gemms, bs, model, - fp8, recipe, fp8_model_params, fuse_wgrad_accumulation, parallel_mode=None, ): + fp8 = recipe is not None if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: + if fp8 and recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches - pytest.skip("MXFP8 unsupported for grouped linear.") - if fp8 and recipe.float8_current_scaling(): - pytest.skip("Float8 Current Scaling unsupported for grouped linear.") - if recipe.float8_block_scaling(): - pytest.skip("Grouped linear for FP8 blockwise unsupported.") + if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1591,24 +1585,7 @@ def test_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("parallel_mode", ["column", "row"]) -@pytest.mark.parametrize("recipe", fp8_recipes) -def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): - """Split the tests to save CI time""" - test_grouped_linear_accuracy( - dtype=torch.float32, - num_gemms=6, - bs=2, - model="126m", - fp8=True, - recipe=recipe, - fp8_model_params=True, - parallel_mode=parallel_mode, - fuse_wgrad_accumulation=True, - ) - - -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -1616,7 +1593,6 @@ def test_grouped_linear_accuracy_single_gemm(recipe): num_gemms=1, bs=2, model="126m", - fp8=True, recipe=recipe, fp8_model_params=True, fuse_wgrad_accumulation=True, @@ -1626,9 +1602,12 @@ def test_grouped_linear_accuracy_single_gemm(recipe): def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): - """Padding tensor shapes to multiples of 16.""" + align_size = 16 + if recipe.mxfp8(): + align_size = 32 padded_tokens_per_expert = [ - (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert + (num_tokens + align_size - 1) // align_size * align_size + for num_tokens in tokens_per_expert ] hidden_states = torch.split(hidden_states, tokens_per_expert) padded_hidden_states = [] @@ -1729,12 +1708,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches - pytest.skip("MXFP8 unsupported for grouped linear.") - if fp8 and recipe.float8_current_scaling(): - pytest.skip("Float8 Current Scaling unsupported for grouped linear.") - if recipe.float8_block_scaling(): - pytest.skip("Float8 block scaling unsupported for grouped linear.") + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.seq_len % 16 != 0 and fp8: diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index d8552d63a4..661ee1e046 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -553,14 +553,10 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8(): - pytest.skip("Grouped linear does not support MXFP8") - if fp8_recipe.float8_current_scaling(): - pytest.skip("Grouped linear does not support FP8 current scaling") - if fp8_recipe.float8_block_scaling(): - pytest.skip("Grouped linear does not support FP8 block scaling") if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 79d6391e79..737d92eb75 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -9,10 +9,9 @@ import torch import transformer_engine_torch as tex from ..constants import TE_DType -from ..utils import assert_dim_for_fp8_exec, get_sm_count +from ..utils import get_sm_count from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase @@ -174,14 +173,6 @@ def general_grouped_gemm( transa = layout[0] == "T" transb = layout[1] == "T" - # assert [a.is_contiguous() for a in A] - # assert [b.is_contiguous() for b in B] - - if isinstance(A[0], Float8TensorBase): - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a._data) - assert_dim_for_fp8_exec(b._data) - empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms @@ -208,6 +199,8 @@ def general_grouped_gemm( for o in out ] # this should differ with respect to single output + # TODO: Move the swizzle to the C++ side. # pylint: disable=fixme + original_scale_inverses_list = [swizzle_inputs(A[i], B[i], layout) for i in range(num_gemms)] bias = tex.te_general_grouped_gemm( A, transa, @@ -227,5 +220,7 @@ def general_grouped_gemm( use_split_accumulator, sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))), ) + for i in range(num_gemms): + reset_swizzled_inputs(A[i], B[i], original_scale_inverses_list[i]) return out, bias, gelu_input diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a66fbf950d..14609762fc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -101,18 +101,22 @@ std::optional> te_general_grouped_gemm( bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +namespace transformer_engine::pytorch { + /*************************************************************************************************** * Transpose **************************************************************************************************/ -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, std::vector quantizer_list, transformer_engine::DType otype); at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output = std::nullopt); +} // namespace transformer_engine::pytorch + namespace transformer_engine::pytorch { /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 617ba42d4a..9b11ec5685 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -196,12 +196,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); - m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", - py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize, + "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"), + py::arg("quantizer_list"), py::arg("otype")); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), - py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", + py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), + py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax")); diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index fbf31a7f5b..3be719eaf6 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -109,6 +109,7 @@ std::pair Float8Quantizer::create_tensor( } const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); opts = opts.dtype(torch::kFloat32); + // TODO: Replace with an empty tensor. at::Tensor scale_inv = at::reciprocal(scale); py::object ret; if (internal) { diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index e12990f79c..5b8c121517 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -6,27 +6,38 @@ #include -#include "ATen/core/TensorBody.h" #include "extensions.h" +#include "pybind.h" -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, +namespace transformer_engine::pytorch { + +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, std::vector quantizer_list, transformer_engine::DType otype) { - using namespace transformer_engine::pytorch; + init_extension(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; std::vector py_output_objects_list; std::vector tensor_wrappers; - auto none = py::none(); + if (output_list.has_value()) { + py_output_objects_list = output_list.value(); + } + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; // create TE tensors from input for (size_t i = 0; i < input_list.size(); i++) { - auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + auto input_tensor = makeTransformerEngineTensor(input_list[i]); const NVTEShape input_shape = input_tensor.shape(); transformer_engine::TensorWrapper output_tensor; + if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) { + with_fused_kernel = false; + } if (output_list == std::nullopt) { std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); @@ -48,16 +59,8 @@ std::vector fused_multi_quantize(std::vector input_list, NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), "Number of input and output tensors must match"); - // Choose implementation - // Note: Currently only have fused kernel for FP8 cast-transpose - bool with_fused_kernel = true; for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - const auto& tensor = nvte_tensor_output_list[i]; - if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { - with_fused_kernel = false; - break; - } - if (nvte_tensor_columnwise_data(tensor) == nullptr) { + if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) { with_fused_kernel = false; break; } @@ -68,10 +71,8 @@ std::vector fused_multi_quantize(std::vector input_list, nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); } else { - for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - // TODO: switch to nvte_quantize_v2 with advanced numerical options - nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], - at::cuda::getCurrentCUDAStream()); + for (size_t i = 0; i < py_output_objects_list.size(); i++) { + quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt); } } return py_output_objects_list; @@ -79,7 +80,7 @@ std::vector fused_multi_quantize(std::vector input_list, at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output) { - using namespace transformer_engine::pytorch; + init_extension(); const auto dim = input.dim(); NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); @@ -106,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, return out; } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 2549d45728..9748408338 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -4,12 +4,13 @@ """FP8 Padding API""" -from typing import Union, List +from typing import List, Optional, Tuple import torch import transformer_engine_torch as tex +from ..fp8 import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module): ---------- num_gemms: int number of GEMMs to be performed simutaneously. + align_size: int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. """ def __init__( self, - num_gemms, + num_gemms: int, + align_size: Optional[int] = None, ) -> None: super().__init__() self.num_gemms = num_gemms + if align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + else: + self.align_size = align_size @no_torch_dynamo() def forward( self, inp: torch.Tensor, m_splits: List[int], - ) -> Union[torch.Tensor, List[int]]: + ) -> Tuple[torch.Tensor, List[int]]: """ Apply the padding to the input. @@ -104,7 +113,12 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." # FP8 padding calculate - padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + padded_m_splits = [ + (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits + ] + # no padding needed + if m_splits == padded_m_splits: + return inp, m_splits if torch.is_grad_enabled(): fn = _Fp8Padding.apply diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 479b91d396..7e1fbcb2a3 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -4,12 +4,13 @@ """FP8 Padding API""" -from typing import List +from typing import List, Optional import torch import transformer_engine_torch as tex +from ..fp8 import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module): ---------- num_gemms: int number of GEMMs to be performed simutaneously. + align_size: int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. """ def __init__( self, - num_gemms, + num_gemms: int, + align_size: Optional[int] = None, ) -> None: super().__init__() self.num_gemms = num_gemms + if align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + else: + self.align_size = align_size @no_torch_dynamo() def forward( @@ -100,7 +109,12 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." # FP8 padding calculate - padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + padded_m_splits = [ + (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits + ] + # no padding needed + if m_splits == padded_m_splits: + return inp if torch.is_grad_enabled(): fn = _Fp8Unpadding.apply diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1ea66a7f2c..4681c1121e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -9,6 +9,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_multi_stream_cublas_workspace, TransformerEngineBaseModule, @@ -37,7 +38,6 @@ from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor.float8_tensor import Float8Tensor from ..cpu_offload import is_cpu_offload_enabled from ..tensor.quantized_tensor import ( @@ -47,7 +47,6 @@ restore_from_saved, ) - __all__ = ["GroupedLinear"] @@ -85,15 +84,6 @@ def forward( biases = weights_and_biases[num_gemms:] device = inp.device - # TODO Support MXFP8 # pylint: disable=fixme - if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): - raise NotImplementedError("GroupedLinear does not yet support MXFP8") - # TODO Support Float8 Current Scaling # pylint: disable=fixme - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): - raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): - raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling") - # Make sure input dimensions are compatible in_features = weights[0].shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" @@ -126,7 +116,11 @@ def forward( for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator inputmats = tex.fused_multi_quantize( inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] ) @@ -167,7 +161,7 @@ def forward( m_splits=m_splits, bias=biases, use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ) if fp8_calibration: @@ -182,6 +176,16 @@ def forward( ctx.weights_shape_1 = weights[0].shape[1] + # TODO: update after #1638 is merged. # pylint: disable=fixme + if weight_requires_grad: + for inputmat in inputmats: + if isinstance(inputmat, QuantizedTensor): + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if inp.requires_grad: + for weight in weights_fp8: + if isinstance(weight, QuantizedTensor): + weight.update_usage(columnwise_usage=True) + tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, @@ -202,6 +206,7 @@ def forward( ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -247,10 +252,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_biases = [None] * ctx.num_gemms if ctx.fp8: if ctx.use_bias: - for i in range(ctx.num_gemms): - grad_biases[i], grad_output[i] = tex.bgrad_quantize( - grad_output_mats[i], ctx.grad_output_quantizers[i] - ) + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready + # for Float8BlockQuantizer. + if ctx.fp8_recipe.float8_block_scaling(): + for i in range(ctx.num_gemms): + grad_biases[i] = grad_output_mats[i].sum(dim=0) + grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i]) + else: + for i in range(ctx.num_gemms): + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] + ) else: grad_output = tex.fused_multi_quantize( grad_output_mats, @@ -269,6 +281,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.requires_dgrad: + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_dgrad.use_split_accumulator + ) dgrad = torch.empty( (sum(ctx.m_splits), ctx.weights_shape_1), dtype=ctx.activation_dtype, @@ -285,10 +304,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], layout="NN", m_splits=ctx.m_splits, grad=True, - use_split_accumulator=_2X_ACC_DGRAD, + use_split_accumulator=dgrad_gemm_use_split_accumulator, ) if ctx.weights_requires_grad: + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: @@ -308,7 +334,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], m_splits=ctx.m_splits, use_bias=ctx.use_bias if grad_biases[0] is None else None, bias=biases, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, ) for i in range(ctx.num_gemms): @@ -374,8 +400,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, - None, # is_grad_enabled - None, # is_grad_enabled + None, + None, *wgrad_list, *grad_biases, ) @@ -425,6 +451,9 @@ class GroupedLinear(TransformerEngineBaseModule): the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and + `parallel_mode` are used to determine the shapes of weights and biases. + The TP communication should be handled in the dispatch and combine stages of MoE models. """ def __init__( @@ -467,7 +496,11 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} + self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._num_fp8_tensors_per_gemm = { + "fwd": 3, + "bwd": 2, + } if tp_group is None: self.tp_size = tp_size @@ -478,6 +511,12 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if self.tp_size > 1 and bias: + raise ValueError( + "GroupedLinear doesn't support bias when TP > 1. " + "Because the TP communication is handled outside of this module." + ) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -504,7 +543,7 @@ def __init__( ), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=self._offsets["weight"] + i, + fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], ) # Construct bias parameters if needed @@ -529,12 +568,18 @@ def __init__( self.reset_parameters(defer_init=device == "meta") - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.apply_bias: - self.gemm_bias_unfused_add = True - else: - self.gemm_bias_unfused_add = False + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + assert not self.tp_size > 1, ( + "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " + "Because the TP communication is handled outside of this module." + ) + self._customize_quantizers_float8_current_scaling(fwd, recipe) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -592,7 +637,7 @@ def forward( produced) """ assert not isinstance( - inp, Float8Tensor + inp, QuantizedTensor ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." @@ -617,20 +662,27 @@ def forward( grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms if self.fp8: input_quantizers = [ - self.quantizers["scaling_fwd"][self._offsets["input"] + i] + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ] for i in range(self.num_gemms) ] + # TODO: use internal after #1638 is merged. # pylint: disable=fixme for i in range(self.num_gemms): - input_quantizers[i].internal = True + input_quantizers[i].internal = False weight_quantizers = [ - self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ] for i in range(self.num_gemms) ] for i in range(self.num_gemms): weight_quantizers[i].internal = True if torch.is_grad_enabled(): grad_output_quantizers = [ - self.quantizers["scaling_bwd"][self._offsets["input"] + i] + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ] for i in range(self.num_gemms) ] for i in range(self.num_gemms): @@ -645,7 +697,7 @@ def forward( args += ( inp, m_splits, - self.apply_bias and not self.gemm_bias_unfused_add, + self.apply_bias, is_first_microbatch, self.fp8, self.fp8_calibration, @@ -665,17 +717,37 @@ def forward( ) out = linear_fn(*args) - if self.gemm_bias_unfused_add: - out_shape = out.shape - out = torch.cat( - [ - o + cast_if_needed(b, self.activation_dtype) - for o, b in zip( - torch.split(out.view(-1, self.out_features), m_splits), bias_tensors - ) - ] - ).view(out_shape) - if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + for i in range(self.num_gemms): + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + else: + for i in range(self.num_gemms): + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon From 48f3ca9090ffc1d373cbe40cfc85076f00488ff4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 14 Apr 2025 10:57:54 -0700 Subject: [PATCH 29/53] [PyTorch] Avoid unnecessary tensor usages when caching for linear op backward (#1676) * Avoid unnecessary tensor usages when caching for linear op backward Signed-off-by: Tim Moon * Debug test failure Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/basic_linear.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f4d4254537..86f17608f4 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -413,7 +413,6 @@ def _functional_forward( x = None x_async = None with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel - own_quantized_x_local = False if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") @@ -429,7 +428,6 @@ def _functional_forward( else: if not isinstance(x_local, QuantizedTensor): x_local = input_quantizer(x_local) - own_quantized_x_local = True x = x_local else: if isinstance(x_local, QuantizedTensor): @@ -528,16 +526,16 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Configure input tensor for backward pass - if own_quantized_x_local: - x_local.update_usage(rowwise_usage=False, columnwise_usage=True) - # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save # input tensor as context for backward pass. if x_local is input: x_local = x_local.detach() + # Configure input tensor for backward pass + if with_quantized_compute and isinstance(x_local, QuantizedTensor): + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + return y, x_local, w @staticmethod From 5fdd7bb94c702a2e48ccd060b84362b9e63ac684 Mon Sep 17 00:00:00 2001 From: Jianbin Chang Date: Tue, 15 Apr 2025 07:25:04 +0800 Subject: [PATCH 30/53] [PyTorch] check and try to generate fp8 weight transpose cache before dgrad backward (#1648) * Add fp8 weight transpose cache check in backward, and regenerated it if it does not exist Signed-off-by: jianbinc * Properly handle fsdp shard model weight input. Signed-off-by: jianbinc * move Float8Tensor to QuantizedTensor in cast_master_weights_to_fp8 UT Signed-off-by: jianbinc * handle Float8TensorBase issue Signed-off-by: jianbinc * fix bug in activation recompute Signed-off-by: jianbinc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: jianbinc Co-authored-by: Kirthi Shankar Sivamani --- .../run_cast_master_weights_to_fp8.py | 16 +++++----------- .../pytorch/module/grouped_linear.py | 8 +++++++- .../pytorch/module/layernorm_linear.py | 6 ++++++ .../pytorch/module/layernorm_mlp.py | 14 ++++++++++++++ transformer_engine/pytorch/module/linear.py | 7 +++++++ transformer_engine/pytorch/tensor/utils.py | 7 +++++++ 6 files changed, 46 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index ec06bb7e48..1b38f72512 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -243,10 +243,10 @@ def __init__(self, weights, lr, dp_group): # Flatten the weights and pad to align with world size raw_data_list = [ - _get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1) + _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) for w in weights ] - if isinstance(weights[0], Float8Tensor): + if isinstance(weights[0], QuantizedTensor): raw_data_list = [_get_raw_data(w).view(-1) for w in weights] else: raw_data_list = [w.view(-1) for w in weights] @@ -282,7 +282,7 @@ def __init__(self, weights, lr, dp_group): self.weight_indices.append((None, None)) self.shard_indices.append((None, None)) - if isinstance(weights[idx], Float8Tensor): + if isinstance(weights[idx], QuantizedTensor): replace_raw_data( weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) ) @@ -378,19 +378,13 @@ def step(self): master_weight -= grad * self.lr # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], Float8Tensor): + if isinstance(self.weights[0], QuantizedTensor): local_weights = [] - for model_weight, local_weight in zip(self.weights, self.local_weights): + for local_weight in self.local_weights: if local_weight is None: local_weights.append(None) continue - quantizer = model_weight._get_quantizer() - if isinstance(quantizer, Float8CurrentScalingQuantizer): - local_weight = quantizer.create_tensor_from_data( - local_weight.view(-1), - model_weight.dtype, - ) local_weights.append(local_weight) cast_master_weights_to_fp8( diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4681c1121e..166612ccd4 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -173,7 +173,7 @@ def forward( weight_quantizers[i].calibrate(weights[i]) if is_grad_enabled: - + ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] # TODO: update after #1638 is merged. # pylint: disable=fixme @@ -294,6 +294,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) + for weight, quantizer in zip(weights, ctx.weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) general_grouped_gemm( weights, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index df3ae05f31..c82a0e2153 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -323,6 +323,7 @@ def forward( clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) @@ -651,6 +652,11 @@ def backward( if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor): + weight.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) dgrad, *_ = general_gemm( weight, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e51fe43cc0..1bf791c12b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -478,6 +478,8 @@ def forward( fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) + ctx.fc1_weight_quantizer = fc1_weight_quantizer + ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: clear_tensor_data(ln_out) @@ -749,6 +751,11 @@ def backward( ) # FC2 DGRAD; Unconditional + if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor): + ctx.fc2_weight.update_usage( + rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage, + ) gemm_output, *_ = general_gemm( fc2_weight, grad_output, @@ -895,6 +902,13 @@ def backward( fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) # FC1 DGRAD: Unconditional + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensor + ): + ctx.fc1_weight.update_usage( + rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage, + columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage, + ) fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( fc1_weight, dact, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2887b2e452..2556987fed 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -277,6 +277,7 @@ def forward( nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer saved_inputmat = None ctx.backward_input_needs_gather = ( @@ -574,6 +575,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], recipe.fp8_gemm_dgrad.use_split_accumulator ) + if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor): + weight_fp8.update_usage( + rowwise_usage=ctx.weight_quantizer.rowwise_usage, + columnwise_usage=ctx.weight_quantizer.columnwise_usage, + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 33c0953d94..8dd04b52d0 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -305,4 +305,11 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo amax=torch.Tensor(), fp8_dtype=model_weight._fp8_dtype, ) + if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor): + # NOTE: The fsdp shard model weight may be a unit8 tensor instead of + # a float8 tensor. We should handle this situation properly. + model_weight_fragment = quantizer.create_tensor_from_data( + model_weight_fragment.view(-1), + model_weight.dtype, + ) quantizer.update_quantized(master_weight, model_weight_fragment) From 313ab4f44fe885d3e99ea69392ce462bdf129374 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 14 Apr 2025 21:27:58 -0400 Subject: [PATCH 31/53] [JAX] Improving the test_multiprocessing_encoder.py run script (#1673) * script improvement * add wait * add return code back * relax tols for FP8 test in test_multiprocessing_ by 0.001 --------- Signed-off-by: Phuong Nguyen --- .../run_test_multiprocessing_encoder.sh | 70 ++++++++++++------- .../encoder/test_multiprocessing_encoder.py | 2 +- qa/L0_jax_distributed_unittest/test.sh | 2 + 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index ff38c7e335..56fa28dde7 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -4,32 +4,54 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} -for i in $(seq 0 $(($NUM_GPUS-1))) -do - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i & -done -wait +# Define the test cases to run +TEST_CASES=( +"test_te_bf16" +"test_te_delayed_scaling_fp8" +"test_te_mxfp8" +"test_te_bf16_shardy" +"test_te_delayed_scaling_fp8_shardy" +) -for i in $(seq 0 $(($NUM_GPUS-1))) -do - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i & -done -wait +echo +echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" -for i in $(seq 0 $(($NUM_GPUS-1))) -do - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i & -done -wait +HAS_FAILURE=0 # Global failure flag -for i in $(seq 0 $(($NUM_GPUS-1))) -do - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16_shardy --num-process=$NUM_GPUS --process-id=$i & -done -wait +# Run each test case across all GPUs +for TEST_CASE in "${TEST_CASES[@]}"; do + echo + echo "=== Starting test: $TEST_CASE ..." + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_CASE}_gpu_${i}.log" -for i in $(seq 0 $(($NUM_GPUS-1))) -do - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8_shardy --num-process=$NUM_GPUS --process-id=$i & + # Run pytest and redirect stdout and stderr to the log file + pytest -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + done + + # Wait for the process to finish + wait + + # Check and print the log content accordingly + if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + HAS_FAILURE=1 + echo "... $TEST_CASE FAILED" + tail -n +7 "${TEST_CASE}_gpu_0.log" + elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" + else + echo "Invalid ${TEST_CASE}_gpu_0.log" + fi + + # Remove the log file after processing it + rm ${TEST_CASE}_gpu_*.log done -wait + +exit $HAS_FAILURE diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 56d386a3f5..352160a8ed 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -609,7 +609,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3fbfb9cf5c..377a57e909 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -23,7 +23,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +wait . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" if [ $RET -ne 0 ]; then From aee78831a69574828fac3d0a06bf477b414d295a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:00:46 +0200 Subject: [PATCH 32/53] [PyTorch] Fix for checkpointing for callables. (#1679) * fix Signed-off-by: Pawel Gadzinski * added test Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test change Signed-off-by: Pawel Gadzinski * changed the test Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_sanity.py | 29 +++++++++++++++++++++++ transformer_engine/pytorch/distributed.py | 11 +++++---- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 661ee1e046..afb17d388a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -42,6 +42,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.utils import replace_raw_data +from transformer_engine.pytorch.distributed import checkpoint from test_numerics import reset_rng_states, dtype_tols # Only run FP8 tests on supported devices. @@ -1285,3 +1286,31 @@ def test_fp8_model_init_high_precision_init_val(): assert not hasattr( weight, "._high_precision_init_val" ), "clear_high_precision_init_val() not work" + + +def test_sanity_checkpointing_on_callables(): + """Test that TE checkpointing works correctly on callable modules.""" + + # torch.autograf.function + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + return inp + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + module = MyFunction.apply + inp = torch.randn(10, 10, device="cuda", requires_grad=True) + + out_checkpoint = checkpoint(module, inp) + out_checkpoint.sum().backward() + grad_checkpoint = inp.grad + + out_standard = module(inp) + out_standard.sum().backward() + grad_standard = inp.grad + + # Assert that gradients are the same + torch.testing.assert_close(grad_checkpoint, grad_standard) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 7a1fde164b..890a1835a8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -661,10 +661,13 @@ def checkpoint( **kwargs, ) - # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need - # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + from .module.base import TransformerEngineBaseModule + + if isinstance(function, TransformerEngineBaseModule): + # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need + # to scatter/gather activations that we will recompute anyway. + setattr(function, "fsdp_wrapped", False) + setattr(function, "fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing From 66d6afbf6e4bccd3d40e85a3d5ed307229137304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:04:09 +0200 Subject: [PATCH 33/53] [PyTorch] More precise test for the CPU offloading. (#1668) * test change Signed-off-by: Pawel Gadzinski * test fix Signed-off-by: Pawel Gadzinski * small changes Signed-off-by: Pawel Gadzinski * small changes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clear Signed-off-by: Pawel Gadzinski * base Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_cpu_offloading.py | 143 +++++++++++++++++++-------- 2 files changed, 102 insertions(+), 43 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 1206012195..ffdc088a49 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -39,7 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ed7cdda85b..ab4b7634b8 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,41 +2,84 @@ # # See LICENSE for license information. +import os +from contextlib import nullcontext import pytest import torch -from contextlib import nullcontext import transformer_engine.pytorch as te +from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -# Check if FP8 supported +# Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + +fp8_recipes = [ + None, # non-fp8 + # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet + recipe.Float8CurrentScaling(), + recipe.DelayedScaling(), +] SIZE = 512 +NUM_HEADS = 8 +NUM_LAYERS = 5 +EPSILON = 0.1 + +# Flash attention saves some internal tensor for the backward pass +# that cannot be offloaded to CPU. +assert os.getenv("NVTE_FLASH_ATTN") == "0" -models = { - "linear": te.Linear, - "layernorm_mlp": te.LayerNormMLP, - "layernorm_linear": te.LayerNormLinear, +# Offloading is supported for attention only for fused and flash attention backends, +# so the use of bfloat16 is required. +# +# For the TransformerLayer, activation offloading with dropout is not supported, +# so we set hidden_dropout to 0.0. +model_types = { + "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), + "multihead_attention": lambda: te.MultiheadAttention( + SIZE, NUM_HEADS, params_dtype=torch.bfloat16 + ), + "transformer_layer": lambda: te.TransformerLayer( + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + ), } def _get_input(): - return torch.empty((128, SIZE, SIZE)).cuda() + return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() + + +def _get_fp8_weight_cache_size(models, fp8_recipe): + """ + Calculate the total FP8 weight cache size (in MB) for a list of models. + """ + if fp8_recipe is None: + return 0 + params_bytes = 0 + for model in models: + for name, param in model.named_parameters(): + if "weight" in name: + params_bytes += param.numel() -def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): + # One byte for columnwise and one byte for rowwise, + # hence multiply by 2 and convert to MB + # there is 1 byte of scale per 32 elements in mxFP8 + factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 + return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) - input_layer = model_cls(SIZE, SIZE) - hidden_layer = model_cls(SIZE, SIZE) - output_layer = model_cls(SIZE, SIZE) - input = _get_input() +def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): + tensor = _get_input() if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=2, - model_layers=3, + num_layers=len(models) - 1, + model_layers=len(models), offload_activations=True, offload_weights=False, ) @@ -44,42 +87,58 @@ def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): offload_context = nullcontext() sync_function = lambda x: x - with te.fp8_autocast(enabled=fp8), offload_context: - out = input_layer(input) - out = sync_function(out) - with te.fp8_autocast(enabled=fp8), offload_context: - out = hidden_layer(out) - out = sync_function(out) - with te.fp8_autocast(enabled=fp8), offload_context: - out = output_layer(out) - out = sync_function(out) - - max_mem_used = torch.cuda.memory_allocated() / 1024**2 - - out.sum().backward() - - del input_layer - del hidden_layer - del output_layer - del input - del out + for model in models: + with te.fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + ), offload_context: + tensor = model(tensor) + tensor = sync_function(tensor) + max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() return max_mem_used -@pytest.mark.parametrize("fp8", [True, False]) -@pytest.mark.parametrize("model_key", models.keys()) -def test_cpu_offload(fp8, model_key) -> None: +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("model_key", model_types.keys()) +def test_cpu_offload(fp8_recipe, model_key) -> None: + """ + We run three configurations: + (1) No offloading: All activations remain on the GPU between forward and backward passes. + (2) No offloading (one layer): Only the first layer's activations remain on the GPU between + forward and backward passes. + (3) With offloading (all layers): Only the last layer's activations remain on the GPU + between forward and backward passes, while all other layers are offloaded to the CPU. - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) + We expect the memory consumption of configurations (2) and (3) to be similar, with + the difference being the size of the FP8 cache that is not offloaded to the CPU. + We also expect this memory consumption to be smaller than in scenario (1). + """ - model_cls = models[model_key] + model_cls = model_types[model_key] + models_list = [model_cls() for _ in range(NUM_LAYERS)] - without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) - - with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) + if fp8_recipe and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None: + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + without_offloading = _measure_memory_between_forward_and_backward( + models_list, fp8_recipe, False + ) + without_offloading_one_layer = _measure_memory_between_forward_and_backward( + models_list[:1], fp8_recipe, False + ) + with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) assert with_offloading < without_offloading + + # The only difference between the memory consumption of with_offloading + # and without_offloading_one_layer should be the size of the FP8 weights cache, + # which is not offloaded to the CPU. + memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) + assert ( + memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + ) From 86928e07ec55dad247bc1745d3a3ce9202a969da Mon Sep 17 00:00:00 2001 From: Li Tao Date: Wed, 16 Apr 2025 01:28:22 +0800 Subject: [PATCH 34/53] Add adam bf16 state with original fp32 kernel (#1640) * support adam bf16 state Signed-off-by: XiaobingSuper * use fp32 kernel but keep bf16 optimizer states to save memory Signed-off-by: lit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: XiaobingSuper Signed-off-by: lit Co-authored-by: XiaobingSuper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fused_optimizer.py | 28 +++++++++++++++++++ .../pytorch/optimizers/fused_adam.py | 18 ++++++++---- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 507fd3f350..cec25803f2 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -360,6 +360,20 @@ def test_fp16_exp_avg(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_bf16_exp_avg(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.bfloat16, + exp_avg_sq_dtype=torch.float32, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): @@ -389,6 +403,20 @@ def test_fp16_exp_avg_sq(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_bf16_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.bfloat16, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 070f46e937..18f7e2031a 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -133,10 +133,10 @@ def __init__( # Add constraints to dtypes of states. if master_weights and master_weight_dtype not in [torch.float32, torch.float16]: raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.") - if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]: - raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.") - if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]: - raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.") + if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.") + if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.") # Currently, capturable mode only supports fp32 master weights and optimizer states. # The reason is, if the master weights or optimizer states are not in fp32 dtype, @@ -259,6 +259,10 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): scale (torch.Tensor): A FP32 tensor representing the scaling factor. """ assert unscaled_state.dtype == torch.float32 + if scaled_state.dtype == torch.bfloat16: + scaled_state.copy_(unscaled_state.bfloat16()) + return + dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) @@ -313,8 +317,11 @@ def get_unscaled_state(self, param, state_name): else: assert state[state_name].dtype == torch.float32 unscaled = state[state_name] + elif dtype == torch.bfloat16: + assert state[state_name].dtype == torch.bfloat16 + unscaled = state[state_name].float() else: - raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") + raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.") return unscaled def set_scaled_state(self, param, state_name, unscaled_state): @@ -329,6 +336,7 @@ def set_scaled_state(self, param, state_name, unscaled_state): and 'master_param`. unscaled_state (torch.Tensor): The original high-precision(FP32) state. """ + store_param_remainders = ( self.store_param_remainders and state_name == "master_param" From 0994fb48ba959c3d742a39db68461539f2bb2a37 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:30:26 -0700 Subject: [PATCH 35/53] Fix #1524 and other softmax mask functionality (#1681) * Add test cases for full coverage in jax/test_layer.py - causal and window size None - causal and window size default (-1,1) - no_mask and window size default (-1,1) - no_mask and window size default (2,2) - padding and window size None - padding_causal and window_size (2,2) Signed-off-by: Kshitij Janardan Lakhani * Correct the condition where padding_causal_mask was being mapped to scaled upper triangle Signed-off-by: Kshitij Janardan Lakhani * Fix Issue #1524 Signed-off-by: Kshitij Janardan Lakhani * Add a runner and test cases for jax.flax.module.Softmax class for fwd pass only Segregate runner classes for Softmax module and softmax primitives Signed-off-by: Kshitij Janardan Lakhani * Simplify logic when picking softmax primitives and softmax jax framework calls Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify the logic for performing jax based softmax Signed-off-by: Kshitij Janardan Lakhani * Code clean up Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support table for mask, SWA and Softmax type. Code linting Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Explicit SWA conditons in comments. Fix Typo Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve typo to remove None in SWA comments section Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Janardan Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_layer.py | 47 ++++++++- tests/jax/test_softmax.py | 96 ++++++++++++++++++- .../jax/cpp_extensions/softmax.py | 30 ++++-- transformer_engine/jax/flax/module.py | 45 ++++----- transformer_engine/jax/flax/transformer.py | 16 +++- 5 files changed, 188 insertions(+), 46 deletions(-) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index a21583a98c..d59e130530 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -215,12 +215,53 @@ def enable_fused_attn(): _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, }, # attrs22 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "causal", + _KEY_OF_WINDOW_SIZE: None, + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + # attrs23 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "causal", + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + # attrs24 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask", + }, + # attrs25 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask", + _KEY_OF_WINDOW_SIZE: (2, 2), + }, + # attrs26 { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", _KEY_OF_WINDOW_SIZE: (2, 2), }, + # attrs27 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", + _KEY_OF_WINDOW_SIZE: None, + }, + # attrs28 + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_WINDOW_SIZE: (2, 2), + }, ] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] @@ -370,13 +411,13 @@ def generate_inputs(self, data_shape, dtype): data_rng = jax.random.PRNGKey(2024) inputs = (jax.random.normal(data_rng, data_shape, dtype),) - padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) - causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) + mask_shape = (batch, 1, seqlen, seqlen) + padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8) + causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1) if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: mask = causal_mask else: mask = padded_mask - ref_masks = (1 - mask,) test_masks = (None, mask) # The second arg of Transformer is encoded tokens. diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 8cc8448979..09386c92ed 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -18,6 +18,7 @@ from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available from transformer_engine.jax.softmax import SoftmaxType, softmax +from transformer_engine.jax.flax.module import Softmax def catch_unsupported(method): @@ -94,7 +95,6 @@ def _setup_inputs(self): case _: raise ValueError(f"Unknown {self.softmax_type=}") - @catch_unsupported def test_forward(self): """ Test transformer_engine.jax.softmax.softmax fwd rule @@ -104,7 +104,6 @@ def test_forward(self): reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - @catch_unsupported def test_backward(self): """ Test transformer_engine.jax.softmax.softmax bwd rule @@ -141,6 +140,50 @@ def grad_func(func, *args, **kwargs): assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) +class SoftmaxPrimitivesRunner(SoftmaxRunner): + """ + Jax Softmax Primitives runner + """ + + @catch_unsupported + def test_forward(self): + return super().test_forward() + + @catch_unsupported + def test_backward(self): + return super().test_backward() + + +class SoftmaxModuleRunner: + """ + Jax Softmax Module runner + """ + + module_runner: SoftmaxRunner + bias: None + + def __init__(self, module_runner, bias): + self.module_runner = module_runner + self.bias = bias + + def test_forward(self): + """ + Test transformer_engine.jax.flax.module.Softmax fwd rule + """ + runner = self.module_runner + runner._setup_inputs() + rng = jax.random.PRNGKey(0) + softmax_module = Softmax( + scale_factor=runner.scale_factor, + softmax_type=runner.softmax_type, + ) + softmax_vars = softmax_module.init(rng, runner.logits, runner.mask) + module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask) + reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor) + assert_allclose(module_out, reference_out, dtype=runner.dtype) + + +# Run softmax primitives test @pytest.mark.parametrize( "b, s_q, s_kv, h", [ @@ -165,7 +208,7 @@ def grad_func(func, *args, **kwargs): pytest.param(jnp.float16, id="FP16"), ], ) -class TestSoftmax: +class TestSoftmaxPrimitives: """ Test transformer_engine.jax.softmax.softmax """ @@ -175,7 +218,7 @@ def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): """ Test forward with parameterized configs """ - runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner.test_forward() @staticmethod @@ -183,5 +226,48 @@ def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): """ Test forward with parameterized configs """ - runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner.test_backward() + + +# Run Softmax module test +@pytest.mark.parametrize( + "b, s_q, s_kv, h", + [ + pytest.param(8, 16, 16, 16, id="8-16-16-16"), + pytest.param(8, 512, 512, 16, id="8-512-512-16"), + pytest.param(2, 8, 16384, 8, id="2-8-16384-8"), + # triggers backup framework implementation due to (s_q % 4) != 0 + pytest.param(8, 511, 512, 16, id="8-511-512-16"), + ], +) +@pytest.mark.parametrize("scale_factor", [0.125]) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(SoftmaxType.SCALED, id="SCALED"), + pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"), + pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(jnp.bfloat16, id="BF16"), + pytest.param(jnp.float16, id="FP16"), + ], +) +class TestSoftmaxModule: + """ + Test transformer_engine.jax.flax.module.Softmax + """ + + @staticmethod + def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): + """ + Test forward with parameterized configs + """ + module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + bias = None + runner = SoftmaxModuleRunner(module_runner, bias) + runner.test_forward() diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index e16bf76c94..1556fa3344 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -31,6 +31,9 @@ "scaled_upper_triang_masked_softmax_fwd", "scaled_upper_triang_masked_softmax_bwd", "is_softmax_kernel_available", + "jax_scaled_softmax", + "jax_scaled_masked_softmax", + "jax_scaled_upper_triang_masked_softmax", ] @@ -422,7 +425,7 @@ def scaled_softmax_bwd( Return FP16/BF16 tensor """ if not ScaledSoftmaxBwdPrimitive.enabled(): - _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits) + _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits) return vjp_func(dz)[0] return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( @@ -795,11 +798,17 @@ def shardy_sharding_rule(*args): register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) -def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): +def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled softmax + """ return jax.nn.softmax(scale_factor * logits) -def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): +def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled and masked softmax + """ if mask is not None: logits += jax.lax.select( mask > 0, @@ -809,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac return jax.nn.softmax(logits * scale_factor) -def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): +def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): + """ + JAX based implementation of scaled and upper triangle masked softmax + """ mask = 1 - jnp.tril(jnp.ones_like(logits)) logits += jax.lax.select( mask > 0, @@ -825,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: Return FP16/BF16 tensor """ if not ScaledSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_softmax(logits, scale_factor) + return jax_scaled_softmax(logits, scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) @@ -837,7 +849,7 @@ def scaled_masked_softmax_fwd( Return FP16/BF16 tensor """ if not ScaledMaskedSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_masked_softmax(logits, mask, scale_factor) + return jax_scaled_masked_softmax(logits, mask, scale_factor) return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, mask, scale_factor=scale_factor ) @@ -856,7 +868,7 @@ def scaled_masked_softmax_bwd( """ if not ScaledMaskedSoftmaxBwdPrimitive.enabled(): _, vjp_func = jax.vjp( - partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask + partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask ) return vjp_func(dz)[0] return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( @@ -870,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl Return FP16/BF16 tensor """ if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled(): - return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor) + return jax_scaled_upper_triang_masked_softmax(logits, scale_factor) return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, scale_factor=scale_factor ) @@ -885,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd( """ if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled(): _, vjp_func = jax.vjp( - partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits + partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits ) return vjp_func(dz)[0] return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 45ff8d7ed9..ef60052768 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -13,7 +13,6 @@ from flax import linen as nn from flax.linen import partitioning as nn_partitioning from jax import lax -from jax import nn as jax_nn from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name @@ -26,7 +25,12 @@ from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes -from ..cpp_extensions import is_softmax_kernel_available +from ..cpp_extensions import ( + is_softmax_kernel_available, + jax_scaled_softmax, + jax_scaled_masked_softmax, + jax_scaled_upper_triang_masked_softmax, +) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..sharding import get_non_contracting_logical_axes @@ -168,10 +172,10 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp input_dtype = inputs.dtype logits = inputs - if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( + # use primitives + if is_softmax_kernel_available( self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype ): - if bias is not None: logits = logits + bias.astype(input_dtype) @@ -180,31 +184,22 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp mask_ = None outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type) + # use default jax based implementation else: - attention_bias = None - if mask is not None: - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, -1e10), - jnp.full(mask.shape, 0.0), - ) - attention_bias = attention_bias.astype(input_dtype) - if bias is not None: - attention_bias = _combine_biases(attention_bias, bias) - - if attention_bias is not None: - logits = logits + attention_bias.astype(input_dtype) + logits = logits + bias.astype(input_dtype) - # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED - # and kernel is unavailable, then try on pure scaled softmax custom calls. - if is_softmax_kernel_available( - SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype - ): - outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) + if self.softmax_type is SoftmaxType.SCALED: + outputs = jax_scaled_softmax(logits, self.scale_factor) + elif self.softmax_type is SoftmaxType.SCALED_MASKED: + outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor) + elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: + outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor) else: - outputs = jax_nn.softmax(logits * self.scale_factor) - + raise ValueError( + f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED," + " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]" + ) assert input_dtype == outputs.dtype return outputs diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 70a4da9186..10a0a06824 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -220,11 +220,11 @@ def convert_to_softmax_type(attn_mask_type, mask): if mask is not None: mask = apply_swa_mask(mask) # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if mask is not None: + return SoftmaxType.SCALED_MASKED, mask + if attn_mask_type is AttnMaskType.CAUSAL_MASK: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: - if mask is not None: - return SoftmaxType.SCALED_MASKED, mask + if attn_mask_type is AttnMaskType.NO_MASK: return SoftmaxType.SCALED, mask raise ValueError( f"Unsupported {attn_mask_type=}, supported attn_mask_type=" @@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + attn_mask_type mask/sequence_descriptor SWA softmax type + -------------------------------------------------------------------------------------------- + no_mask None None SCALED + causal None None SCALED_UPPER_TRIANG_MASKED + causal None Yes SCALED_MASKED + padding Required Yes/No SCALED_MASKED + padding_causal Required Yes/No SCALED_MASKED + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. From beaecf84ea93c817b0fe7e0611e49c6c0b2e6b30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 16 Apr 2025 13:39:08 +0200 Subject: [PATCH 36/53] =?UTF-8?q?[Pytorch]=20NVIDIA-DL-Framework-Inspect?= =?UTF-8?q?=20support=20=E2=80=93=20part=201=20=E2=80=93=20core=20(#1614)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add Signed-off-by: Pawel Gadzinski * weight workspace fix Signed-off-by: Pawel Gadzinski * docs fix Signed-off-by: Pawel Gadzinski * file i forgot Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * lint fix Signed-off-by: Pawel Gadzinski * Update transformer_engine/debug/pytorch/utils.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * setup fix Signed-off-by: Pawel Gadzinski * setup fix Signed-off-by: Pawel Gadzinski * Update transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * all tensor types Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed check Signed-off-by: Pawel Gadzinski * move error Signed-off-by: Pawel Gadzinski * _reset Signed-off-by: Pawel Gadzinski * Update transformer_engine/pytorch/module/linear.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * name documentation Signed-off-by: Pawel Gadzinski * added blockwise quantizer Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make debug option optional Signed-off-by: Pawel Gadzinski * Update transformer_engine/pytorch/tensor/quantized_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * names fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- setup.py | 4 + transformer_engine/debug/__init__.py | 11 + transformer_engine/debug/pytorch/__init__.py | 3 + .../debug/pytorch/debug_quantization.py | 528 ++++++++++++++++++ .../debug/pytorch/debug_state.py | 68 +++ transformer_engine/debug/pytorch/utils.py | 10 + transformer_engine/pytorch/attention.py | 15 + .../pytorch/cpp_extensions/gemm.py | 11 + transformer_engine/pytorch/distributed.py | 25 +- transformer_engine/pytorch/module/base.py | 85 ++- .../pytorch/module/layernorm_linear.py | 85 ++- .../pytorch/module/layernorm_mlp.py | 266 ++++++--- transformer_engine/pytorch/module/linear.py | 86 ++- transformer_engine/pytorch/tensor/__init__.py | 24 + .../tensor/_internal/float8_tensor_base.py | 6 +- .../pytorch/tensor/quantized_tensor.py | 12 +- transformer_engine/pytorch/transformer.py | 13 + transformer_engine/pytorch/utils.py | 14 + 18 files changed, 1148 insertions(+), 118 deletions(-) create mode 100644 transformer_engine/debug/__init__.py create mode 100644 transformer_engine/debug/pytorch/__init__.py create mode 100644 transformer_engine/debug/pytorch/debug_quantization.py create mode 100644 transformer_engine/debug/pytorch/debug_state.py create mode 100644 transformer_engine/debug/pytorch/utils.py diff --git a/setup.py b/setup.py index e1977601f5..6969ad76e7 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch>=2.1"]) + install_reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) diff --git a/transformer_engine/debug/__init__.py b/transformer_engine/debug/__init__.py new file mode 100644 index 0000000000..62f7f41728 --- /dev/null +++ b/transformer_engine/debug/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Top level package for numerical debugging.""" + +try: + from . import pytorch + from .pytorch.debug_state import set_weight_tensor_tp_group_reduce +except ImportError as e: + pass diff --git a/transformer_engine/debug/pytorch/__init__.py b/transformer_engine/debug/pytorch/__init__.py new file mode 100644 index 0000000000..8bdbe287de --- /dev/null +++ b/transformer_engine/debug/pytorch/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py new file mode 100644 index 0000000000..4a7a156a0a --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -0,0 +1,528 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +This file contains DebugQuantizer and DebugQuantizedTensor objects, +which are wrappers over Quantizer and QuantizedTensor. +These wrappers add logic related to debugging, using the nvdlfw_inspect package. +""" + +from __future__ import annotations +from typing import Optional, Tuple, Iterable, Union +import torch + +import transformer_engine_torch as tex + + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +aten = torch.ops.aten + +_tensor_to_gemm_names_map = { + "weight": ["fprop", "dgrad"], + "activation": ["fprop", "wgrad"], + "output": ["fprop", None], + "gradient": ["dgrad", "wgrad"], + "wgrad": ["wgrad", None], + "dgrad": ["dgrad", None], +} + +API_CALL_MODIFY = "modify_tensor()" +STANDARD_FP8_QUANTIZE = "FP8 Quantize" +HIGH_PRECISION = "High Precision" + + +class DebugQuantizer(Quantizer): + """ + DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect. + It allows adding custom calls inside the quantization process - which enables modifying tensors + or gathering tensor stats. + """ + + def __init__( + self, + layer_name: str, + tensor_name: str, + parent_quantizer: Optional[Quantizer], + tp_group: torch.distributed.ProcessGroup, + ): + import nvdlfw_inspect.api as debug_api + + super().__init__(rowwise=True, columnwise=True) + self.layer_name = layer_name + self.tensor_name = tensor_name + self.parent_quantizer = parent_quantizer + self.tp_group = tp_group # used in inspect_tensor calls + self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, + # rowwise_tensor_plan, and columnwise_tensor_plan are computed. + # These fields indicate the path where API calls will be inserted. + # + # inspect_tensor*_enabled are bool fields, + # indicating whether some feature will need to run inspect_tensor_* calls. + # + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # determining what will happen when the quantizer is used for that tensor. + self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] + if self.output_tensor: + self.inspect_tensor_enabled, self.rowwise_tensor_plan = ( + self.get_plans_for_output_tensors() + ) + else: + ( + self.inspect_tensor_enabled, + self.inspect_tensor_postquantize_enabled_rowwise, + self.inspect_tensor_postquantize_enabled_columnwise, + ) = self.get_enabled_look_at_tensors() + self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan() + + self.log_messages_about_plans() + + def get_plans_for_output_tensors(self) -> Tuple[bool, str]: + """ + Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the + API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support + gemm output in FP8. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION + + return inspect_tensor_enabled, plan + + def get_enabled_look_at_tensors(self): + """ + Returns a tuple of booleans determining which functions look_at_tensor_*(...) should be called. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + inspect_tensor_postquantize_enabled_rowwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.rowwise_gemm_name, + ) + ) + inspect_tensor_postquantize_enabled_columnwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.columnwise_gemm_name, + ) + ) + + return ( + inspect_tensor_enabled, + inspect_tensor_postquantize_enabled_rowwise, + inspect_tensor_postquantize_enabled_columnwise, + ) + + def get_tensors_plan(self): + """ + Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of + API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + of this quantizer with respect to these tensors. + """ + import nvdlfw_inspect.api as debug_api + + rowwise_plan = None + columnwise_plan = None + + modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_rowwise: + rowwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + rowwise_plan = STANDARD_FP8_QUANTIZE + if rowwise_plan is None: + rowwise_plan = HIGH_PRECISION + + if self.columnwise_gemm_name is not None: + modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_columnwise: + columnwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + columnwise_plan = STANDARD_FP8_QUANTIZE + if columnwise_plan is None: + columnwise_plan = HIGH_PRECISION + + return rowwise_plan, columnwise_plan + + def log_messages_about_plans(self): + """ + Logs the messages about the plans for each of the tensors. + """ + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -" + f" {self.rowwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name), + ) + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -" + f" {self.columnwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name), + ) + + def _call_inspect_tensor_api( + self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None + ): + import nvdlfw_inspect.api as debug_api + + args = { + "layer_name": self.layer_name, + "tensor": tensor, + "tensor_name": self.tensor_name, + "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "tp_group": self.tp_group, + } + if tensor is not None and self.inspect_tensor_enabled: + debug_api.transformer_engine.inspect_tensor(**args) + + if self.output_tensor: + return + + if ( + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_rowwise + ): + args["tensor"] = rowwise_gemm_tensor + args["rowwise"] = True + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_columnwise + ): + args["tensor"] = columnwise_gemm_tensor + args["rowwise"] = False + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None, + dtype: torch.dtype = None, + ): + """Returns DebugQuantizedTensor object.""" + import nvdlfw_inspect.api as debug_api + + assert not self.output_tensor + if out is not None: + return self.update_quantized(tensor, self) + + # 1. If there is fp8 quantization in at least one of the gemms, + # the quantization using the self.parent_quantizer is performed. + + # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + if columnwise_gemm_quantize and not rowwise_gemm_quantize: + rowwise_gemm_quantize = True # only columnwise quantization not implemented + + rowwise_gemm_tensor, columnwise_gemm_tensor = None, None + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=True, + columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported + ) + quantized_tensor = self.parent_quantizer(tensor) + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # one tensor with columnwise=True and rowwise=True is computed + # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. + if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + rowwise_gemm_tensor = quantized_tensor + if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + columnwise_gemm_tensor = quantized_tensor + + # 2. modify_tensor() is called, if it is used. + if self.columnwise_tensor_plan == API_CALL_MODIFY: + columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + if self.rowwise_tensor_plan == API_CALL_MODIFY: + rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + + # 3. If some tensors still are not defined we use high precision tensor. + if self.rowwise_tensor_plan == HIGH_PRECISION: + rowwise_gemm_tensor = tensor.to(dtype) + if self.columnwise_tensor_plan == HIGH_PRECISION: + columnwise_gemm_tensor = tensor.to(dtype) + + self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor) + + # sometimes we may want to return simple tensor with only rowwise_gemm + if self.tensor_name in ["wgrad", "dgrad", "output"]: + return rowwise_gemm_tensor + + return DebugQuantizedTensor( + rowwise_gemm_tensor=rowwise_gemm_tensor, + columnwise_gemm_tensor=columnwise_gemm_tensor, + quantizer=self, + layer_name=self.layer_name, + tensor_name=self.tensor_name, + ) + + def process_gemm_output(self, tensor: torch.Tensor): + """This call is invoked after the gemm to inspect and modify the output tensor.""" + import nvdlfw_inspect.api as debug_api + + assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.output_tensor + tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} + if self.rowwise_tensor_plan == API_CALL_MODIFY: + tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + gemm=tensor_to_gemm[self.tensor_name], + tensor_name=self.tensor_name, + tensor=tensor, + iteration=self.iteration, + default_quantizer=self.parent_quantizer, + ) + self._call_inspect_tensor_api(tensor) + return tensor + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Override make_empty() from Quantizer class.""" + if self.parent_quantizer is not None: + return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) + return torch.empty(shape, dtype=dtype, device=device) + + def calibrate(self, tensor: torch.Tensor): + """Calibration override, should not be invoked.""" + raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Update quantized tensor - used in weight caching.""" + import nvdlfw_inspect.api as debug_api + + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + + updated_rowwise_gemm = False + if self.parent_quantizer is not None: + if ( + dst.rowwise_gemm_tensor is not None + and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ): + if hasattr(dst.rowwise_gemm_tensor, "quantize_"): + dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None) + updated_rowwise_gemm = True + if ( + dst.columnwise_gemm_tensor is not None + and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and not updated_rowwise_gemm + ): + if hasattr(dst.columnwise_gemm_tensor, "quantize_"): + dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None) + + if self.columnwise_tensor_plan == API_CALL_MODIFY: + out = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.columnwise_gemm_tensor, + iteration=self.iteration, + ) + assert out is None, ( + "API call debug_api.transformer_engine.modify_tensor with out != None should" + " return None" + ) + if self.rowwise_tensor_plan == API_CALL_MODIFY: + debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.rowwise_gemm_tensor, + iteration=self.iteration, + ) + + if self.rowwise_tensor_plan == HIGH_PRECISION: + dst.rowwise_gemm_tensor.copy_(src) + if self.columnwise_tensor_plan == HIGH_PRECISION: + # if they are the same tensor object, it is sufficient to update one + if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor: + dst.columnwise_gemm_tensor.copy_(src) + + self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) + + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + + +class DebugQuantizedTensor: + """ + Class containing quantized tensors after debug. Depending on configuration + it can contain one or two different objects. These objects can be accessed by the method + get_tensor(). + """ + + def __init__( + self, + rowwise_gemm_tensor, + columnwise_gemm_tensor, + quantizer, + layer_name=None, + tensor_name=None, + ): + + self.rowwise_gemm_tensor = rowwise_gemm_tensor + self.columnwise_gemm_tensor = columnwise_gemm_tensor + self.quantizer = quantizer + self._layer_name = layer_name + self._tensor_name = tensor_name + + def prepare_for_saving(self): + """ " Prepare for saving method override""" + self.tensors_to_save = ( + [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor] + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor + else [self.rowwise_gemm_tensor] + ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) + self.tensors_to_save = tensor_objects_list + # pylint: disable=unbalanced-tuple-unpacking + return tensor_list, self + + def restore_from_saved(self, tensors): + """Restore from saved method override""" + tensor_objects_list, saved_tensors = restore_from_saved( + self.tensors_to_save, + tensors, + return_saved_tensors=True, + ) + if len(tensor_objects_list) == 2: + # pylint: disable=unbalanced-tuple-unpacking + self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list + else: + self.rowwise_gemm_tensor = tensor_objects_list[0] + self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors + + def quantize_(self, tensor, *, noop_flag=None): + """ " quantize_ method override""" + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + self.quantizer.update_quantized(tensor, self) + + def dequantize(self, *, dtype=None): + """ " dequantize method override""" + if dtype is None: + dtype = self.rowwise_gemm_tensor.dtype + return self.rowwise_gemm_tensor.dequantize().to(dtype) + + def get_tensor(self, transpose: bool): + """Is used in the python gemm() to get tensor or transpose of the tensor.""" + return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor + + def size(self): + """Size of the tensor.""" + return self.rowwise_gemm_tensor.size() + + def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + """Update usage of the tensor.""" diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py new file mode 100644 index 0000000000..11edb3641f --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Managing the state of all the debugged layers. +""" + +import sys + + +class TEDebugState: + """ + A class to manage the state of debug layers. + """ + + layer_count = 1 + layers_initialized = {} + weight_tensor_tp_group_reduce = True + debug_enabled = None + + @classmethod + def initialize(cls): + """ + If debug_api module is initialized, then sets cls.debug_enabled to True. + """ + + if "nvdlfw_inspect" in sys.modules: + import nvdlfw_inspect.api as debug_api + + if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None: + # This method is invoked when initializing TE modules. + # If this error is thrown, it means that some TE module had been initialized before + # debug_api was initialized, and now a new TE module is being initialized. + # This is likely to be a bug. + raise RuntimeError( + "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before" + " initialization of the first TE module" + ) + cls.debug_enabled = debug_api.DEBUG_MANAGER is not None + + @classmethod + def _reset(cls): + """Resets layer count and stats buffers.""" + from ..features.utils.stats_buffer import STATS_BUFFERS + + STATS_BUFFERS.reset() + cls.debug_enabled = None + cls.layers_initialized.clear() + + @classmethod + def get_layer_count(cls): + """ + Layer counter is used when layer names are not provided to modules by the user. + """ + lc = cls.layer_count + cls.layer_count += 1 + return lc + + @classmethod + def set_weight_tensor_tp_group_reduce(cls, enabled): + """Sets weight tensor reduction mode.""" + cls.weight_tensor_tp_group_reduce = enabled + + +def set_weight_tensor_tp_group_reduce(enabled): + """Sets weight tensor reduction mode.""" + TEDebugState.set_weight_tensor_tp_group_reduce(enabled) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py new file mode 100644 index 0000000000..4aea05333c --- /dev/null +++ b/transformer_engine/debug/pytorch/utils.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utils functions for the debug module.""" + + +def any_feature_enabled(quantizers): + """Returns True if at least one API call is made from DebugQuantizer.""" + return any(q.any_feature_enabled() for q in quantizers) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8a3f259575..194fed3adf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -19,6 +19,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.utils import ( get_cudnn_version, nvtx_range_pop, @@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module): equal length. Please note that these formats do not reflect how tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. For that, please use `get_qkv_layout` to gain the layout information. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -6561,6 +6564,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -6612,6 +6616,8 @@ def __init__( self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups + self.name = name + common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, "tp_group": tp_group, @@ -6652,6 +6658,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_qkv" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6663,6 +6670,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=parameters_split, + name=name + ".linear_qkv" if name is not None else None, **common_gemm_kwargs, ) elif self.attention_type == "cross": @@ -6684,6 +6692,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_q" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6694,6 +6703,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + name=name + ".linear_q" if name is not None else None, **common_gemm_kwargs, ) self.key_value = Linear( @@ -6704,6 +6714,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=("key", "value") if not fuse_qkv_params else None, + name=name + ".linear_kv" if name is not None else None, **common_gemm_kwargs, ) @@ -6733,6 +6744,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, ub_name="proj", + name=name + ".proj" if name is not None else None, **common_gemm_kwargs, ) @@ -6923,6 +6935,9 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 737d92eb75..62f029bed7 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ "general_gemm", @@ -109,6 +110,13 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + debug_quantizer = None + if isinstance(quantization_params, DebugQuantizer): + debug_quantizer = quantization_params + quantization_params = quantization_params.parent_quantizer + A = A.get_tensor(not transa) + B = B.get_tensor(transb) + # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] @@ -145,6 +153,9 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) reset_swizzled_inputs(A, B, original_scale_inverses) + if debug_quantizer is not None: + out = debug_quantizer.process_gemm_output(out) + return out, bias_grad, gelu_input, extra_output diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 890a1835a8..0e11b2c102 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -19,7 +19,7 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules -from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data +from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm from .constants import dist_group_type from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer @@ -29,6 +29,7 @@ from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -1195,6 +1196,28 @@ def gather_along_first_dim( out_shape=out_shape, ) + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = inp + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor + return out_obj, None + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 65f47a0817..739572b925 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +import logging from types import MethodType import torch @@ -39,6 +40,9 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...common.recipe import Recipe +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor __all__ = ["initialize_ub", "destroy_ub"] @@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + self.name = None self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -432,6 +437,9 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None + if not TEDebugState.debug_enabled: + TEDebugState.initialize() + # Names of attributes that can be set quickly (see __setattr__ # method) _fast_setattr_names: Set[str] = { @@ -848,7 +856,7 @@ def grad_output_preprocess( gather_grad_output = row_parallel_mode and ctx.sequence_parallel # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8: + if not ctx.fp8 and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) @@ -858,6 +866,7 @@ def grad_output_preprocess( return grad_output, None # FP8 with all-gather: unfused bgrad, fused cast + transpose + # Also supports debug quantization, which is handled inside gather_along_first_dim. if gather_grad_output: grad_bias = None if ctx.use_bias: @@ -886,6 +895,23 @@ def grad_output_preprocess( ) return grad_output, grad_bias + # Debug without all-gather: unfused cast and bgrad + # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None + if ctx.debug: + grad_output_ = quantizer(grad_output) + if ( + isinstance( + grad_output_.get_tensor(True), + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ) + and ctx.use_bias + ): + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias = None + grad_output = grad_output_ + return grad_output, grad_bias + # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: @@ -1002,6 +1028,7 @@ def get_weight_workspace( update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values @@ -1024,6 +1051,9 @@ def get_weight_workspace( over `update_workspace` if provided. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. """ # FP8 primary weights @@ -1037,6 +1067,7 @@ def get_weight_workspace( # Try getting workspace from cache out = None + if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) if quantizer is not None and isinstance(out, MXFP8TensorBase): @@ -1047,6 +1078,11 @@ def get_weight_workspace( out = None del self._fp8_workspaces[cache_name] + is_debug = isinstance(quantizer, DebugQuantizer) + is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) + if is_debug != is_out_debug_tensor: + out = None + # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # for models initialized with Fp8 primary weights. @@ -1064,7 +1100,7 @@ def get_weight_workspace( raise ValueError( "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - out = quantizer(tensor) + out = quantizer.quantize(tensor, dtype=workspace_dtype) # Update cache if cache_name is not None: @@ -1081,7 +1117,6 @@ def get_weight_workspace( out.quantize_(tensor, noop_flag=skip_update_flag) else: tex.quantize(tensor, quantizer, out, skip_update_flag) - return out def _load_from_state_dict( @@ -1104,3 +1139,47 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + + def _validate_name(self): + """ + Validate name passed to the module. + This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. + If no name is assigned, it creates a default name with layer count as the variable. + """ + assert TEDebugState.debug_enabled + import nvdlfw_inspect.api as debug_api + + if self.name is None: + debug_api.log_message( + "Names are not provided to debug modules. ", + "Creating and using generic names. Pass names to debug modules for better" + " insight. ", + level=logging.WARNING, + ) + self.name = f"Layer_{TEDebugState.get_layer_count()}" + + def _turn_off_unsupported_features_in_debug(self): + if ( + getattr(self, "ub_bulk_wgrad", False) + or getattr(self, "ub_bulk_dgrad", False) + or getattr(self, "ub_overlap_ag", False) + or getattr(self, "ub_overlap_rs_dgrad", False) + or getattr(self, "ub_overlap_rs", False) + ): + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + "UserBuffers are not supported in debug module. " + "Using UB optimization will not affect the debug module. ", + level=logging.WARNING, + ) + if hasattr(self, "ub_bulk_wgrad"): + self.ub_bulk_wgrad = None + if hasattr(self, "ub_bulk_dgrad"): + self.ub_bulk_dgrad = None + if hasattr(self, "ub_overlap_ag"): + self.ub_overlap_ag = None + if hasattr(self, "ub_overlap_rs_dgrad"): + self.ub_overlap_rs_dgrad = None + if hasattr(self, "ub_overlap_rs"): + self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index c82a0e2153..2cc6e770da 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -35,6 +35,7 @@ nvtx_range_pop, nvtx_range_push, requires_grad, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -56,6 +57,8 @@ prepare_for_saving, restore_from_saved, ) +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer @@ -90,8 +93,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -116,6 +120,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -214,12 +219,12 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: + if fp8 or debug: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = input_quantizer(ln_out_total) else: - if fp8: + if fp8 or debug: if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -233,18 +238,19 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(input_quantizer if fp8 else None), + quantizer=(input_quantizer if fp8 or debug else None), ) else: - if fp8 and not with_quantized_norm: + if (fp8 or debug) and not with_quantized_norm: ln_out = input_quantizer(ln_out) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - if not fp8: - quantized_weight = False - weightmat = cast_if_needed(weight, activation_dtype) + weightmat = weight + quantized_weight = False + if not fp8 and not debug: + weightmat = cast_if_needed(weightmat, activation_dtype) else: quantized_weight = not isinstance(weight, QuantizedTensor) @@ -254,6 +260,7 @@ def forward( # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( tensor=weight, quantizer=weight_quantizer, @@ -261,11 +268,12 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -400,6 +408,7 @@ def forward( if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp @@ -434,6 +443,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.debug = debug # Row Parallel Linear if ub_overlap_rs_fprop: @@ -611,7 +621,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -757,6 +767,7 @@ def backward( out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -865,8 +876,9 @@ def backward( None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -889,6 +901,7 @@ def backward( None, # ub_bulk_wgrad None, # ub_name None, # fsdp_group + None, # debug None, # module None, # skip_fp8_weight_update ) @@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -1007,6 +1022,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + name: str = None, ) -> None: super().__init__() @@ -1023,6 +1039,10 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.name = name + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1312,6 +1332,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1348,13 +1371,28 @@ def forward( else: bias_tensor = getattr(self, self.bias_names[0]) # Unused + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1376,8 +1414,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1402,6 +1441,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1421,8 +1461,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1441,8 +1482,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1bf791c12b..0fd051d781 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -41,6 +41,7 @@ clear_tensor_data, requires_grad, non_tn_fp8_gemm_supported, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -73,6 +74,8 @@ from ..cpp_extensions import ( general_gemm, ) +from ...debug.pytorch.utils import any_feature_enabled +from ...debug.pytorch.debug_state import TEDebugState __all__ = ["LayerNormMLP"] @@ -153,12 +156,16 @@ def forward( fuse_wgrad_accumulation: bool, fc1_input_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer], + fc1_output_quantizer: Optional[Quantizer], + fc1_grad_input_quantizer: Optional[Quantizer], + fc1_grad_weight_quantizer: Optional[Quantizer], + fc1_grad_output_quantizer: Optional[Quantizer], fc2_input_quantizer: Optional[Quantizer], fc2_weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_fc2_output_quantizer: Optional[Quantizer], - grad_fc1_output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], + fc2_output_quantizer: Optional[Quantizer], + fc2_grad_input_quantizer: Optional[Quantizer], + fc2_grad_weight_quantizer: Optional[Quantizer], + fc2_grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -184,6 +191,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -212,9 +220,16 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - # Avoid quantized norm kernel if norm output will be returned + # for fp8 DelayedScaling: layernorm output = FP8 + # only output of the linear is returned + # for return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + # for debug: : layernorm output = High precision to enable processing of this norm with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not debug ) if isinstance(fc1_input_quantizer, Float8BlockQuantizer): # Kernels not available for norm fusion. @@ -270,13 +285,13 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: + if fp8 or debug: if not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: - if fp8: + if fp8 or debug: if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -290,21 +305,21 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(fc1_input_quantizer if fp8 else None), + quantizer=(fc1_input_quantizer if fp8 or debug else None), ) else: # NOTE: force_hp_fc1_input_gather is redundant with else, but # here for clarity. We should not quantize ln_out if bwd needs # to gather in hp. - if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: + if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out # Cast weights to expected dtype - if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) - else: + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + + if fp8 or debug: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. # FP8 cast to workspace buffer @@ -316,6 +331,7 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc2_weight_final = module.get_weight_workspace( @@ -325,11 +341,15 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) # Cast biases to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 if fc1_bias is not None: fc1_bias = cast_if_needed(fc1_bias, bias_dtype) @@ -359,13 +379,16 @@ def forward( gemm_gelu_fusion = True if gemm_gelu_fusion and bias_gelu_fusion: gemm_gelu_fusion = False - + if debug: + gemm_gelu_fusion = False fc1_outputs = general_gemm( fc1_weight_final, ln_out_total, get_workspace(), quantization_params=( - fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + fc2_input_quantizer + if gemm_gelu_fusion + else fc1_output_quantizer # fused gelu output is in fp8 ), out_dtype=activation_dtype, bias=( @@ -376,6 +399,7 @@ def forward( ub=ub_obj_lnout, ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): clear_tensor_data(ln_out_total) @@ -389,6 +413,10 @@ def forward( act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) elif gemm_gelu_fusion: act_out, _, fc1_out, _ = fc1_outputs + elif debug: + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): @@ -426,7 +454,7 @@ def forward( get_workspace(), out_dtype=activation_dtype, bias=fc2_bias, - quantization_params=output_quantizer, + quantization_params=fc2_output_quantizer, out=fc2_out, use_split_accumulator=_2X_ACC_FPROP, ub=ub_obj_fc2out, @@ -515,11 +543,14 @@ def forward( ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather - ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer - ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer + ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer + ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer + ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer + ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer + ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad @@ -552,6 +583,7 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag + ctx.debug = debug ctx.requires_dgrad = ( inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad @@ -675,18 +707,18 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_fc2_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None: rowwise_usage = True columnwise_usage = True if ctx.ub_overlap_ag and isinstance( - ctx.grad_fc2_output_quantizer, + ctx.fc2_grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer), ): # If data is in FP8 and communication is handled # with Userbuffers, we compute FP8 transposes # manually columnwise_usage = False - ctx.grad_fc2_output_quantizer.set_usage( + ctx.fc2_grad_output_quantizer.set_usage( rowwise=rowwise_usage, columnwise=columnwise_usage, ) @@ -701,7 +733,7 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer ) # Launch tensor-parallel communication for FC1 GEMM input @@ -714,7 +746,7 @@ def backward( and not ctx.ub_bulk_dgrad ): quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -747,7 +779,10 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + not ctx.fp8 + and (ctx.activation == "gelu") + and (not ctx.bias_gelu_fusion) + and (not ctx.debug) ) # FC2 DGRAD; Unconditional @@ -763,7 +798,9 @@ def backward( layout="NN", grad=True, quantization_params=( - ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ctx.fc1_grad_input_quantizer + if fc2_dgrad_gemm_gelu_fusion or ctx.debug + else None ), # high precision to activation out_dtype=ctx.activation_dtype, gelu=fc2_dgrad_gemm_gelu_fusion, @@ -798,7 +835,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - quantization_params=None, # wgrad in high precision + quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision layout="NT", grad=grad_arg, bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, @@ -817,15 +854,20 @@ def backward( # bias computation fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.grad_fc1_output_quantizer is not None: - ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.fc1_grad_output_quantizer is not None: + ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) - if ctx.grad_fc1_output_quantizer is not None: - dact = ctx.grad_fc1_output_quantizer(dact) + if ctx.fc1_grad_output_quantizer is not None: + dact = ctx.fc1_grad_output_quantizer(dact) + elif ctx.debug: + dact_func = _act_func(ctx.activation)[1] + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + fc1_bias_grad = dact.sum(dim=0) + dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and ctx.fp8 @@ -835,7 +877,7 @@ def backward( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -849,12 +891,12 @@ def backward( if ctx.fp8: # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): + if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) - dact = ctx.grad_fc1_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact) else: fc1_bias_grad, dact = tex.bgrad_quantize( - dact, ctx.grad_fc1_output_quantizer + dact, ctx.fc1_grad_output_quantizer ) else: fuse_gemm_and_bias_fc1_wgrad = ( @@ -915,6 +957,7 @@ def backward( get_workspace(), out=fc1_dgrad_bulk, out_dtype=ctx.activation_dtype, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, ub=ub_obj_fc1_dgrad, @@ -990,6 +1033,7 @@ def backward( else ctx.activation_dtype ), layout="NT", + quantization_params=ctx.fc1_grad_weight_quantizer, grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, @@ -1123,14 +1167,18 @@ def backward( None, # fp8 None, # fp8_calibration None, # fuse_wgrad_accumulation - None, # fc1_input_quantizer - None, # fc1_weight_quantizer - None, # fc2_input_quantizer - None, # fc2_weight_quantizer - None, # output_quantizer - None, # grad_fc2_output_quantizer - None, # grad_fc1_output_quantizer - None, # grad_input_quantizer + None, # fc1_input_quantizer, + None, # fc1_weight_quantizer, + None, # fc1_output_quantizer, + None, # fc1_grad_input_quantizer, + None, # fc1_grad_weight_quantizer, + None, # fc1_grad_output_quantizer, + None, # fc2_input_quantizer, + None, # fc2_weight_quantizer, + None, # fc2_output_quantizer, + None, # fc2_grad_input_quantizer, + None, # fc2_grad_weight_quantizer, + None, # fc2_grad_output_quantizer, None, # cpu_offloading None, # tp_group None, # tp_size @@ -1156,6 +1204,7 @@ def backward( None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -1208,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -1277,6 +1328,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, + name: str = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1306,6 +1358,10 @@ def __init__( and self.activation == "gelu" and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if tp_group is None: self.tp_size = tp_size @@ -1466,7 +1522,9 @@ def reset_parameters(self, defer_init=False): @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1489,6 +1547,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1503,17 +1564,35 @@ def forward( fp8_output = True with self.prepare_forward(inp, num_gemms=2) as inp: + + quantizers = ( + self._get_quantizers(fp8_output) + if not debug + else self._get_debug_quantizers(fp8_output) + ) + if debug: + if not any_feature_enabled(quantizers): + quantizers = self._get_quantizers(fp8_output) + debug = False + + if isinstance(self.fc1_weight, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + # Get quantizers ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers # Get weight tensors fc1_weight = self.fc1_weight @@ -1551,12 +1630,16 @@ def forward( self.fuse_wgrad_accumulation, fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1565,7 +1648,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8, + self.bias_gelu_nvfusion and not self.fp8 and not debug, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1578,10 +1661,11 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_bulk_dgrad, self.ub_bulk_wgrad, - self.gemm_gelu_fusion, + self.gemm_gelu_fusion and not debug, self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1603,13 +1687,17 @@ def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = [None] * 8 + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = [None] * 12 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = False # temporary @@ -1623,30 +1711,54 @@ def _get_quantizers(self, fp8_output): fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] + fc2_output_quantizer = self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_OUTPUT + ] if torch.is_grad_enabled(): - grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ] - grad_fc2_output_quantizer.internal = True - grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer.internal = True + fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ] - grad_fc1_output_quantizer.internal = True - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] - grad_input_quantizer.internal = True + fc1_grad_output_quantizer.internal = True return ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, ) + def _get_debug_quantizers(self, fp8_output): + from ...debug.pytorch.debug_quantization import DebugQuantizer + + base_quantizers = list(self._get_quantizers(fp8_output)) + assert TEDebugState.debug_enabled + + def make_debug(prefix, offset): + labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return [ + DebugQuantizer( + f"{self.name}.{prefix}", + label, + None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], + self.tp_group, + ) + for i, label in enumerate(labels) + ] + + return tuple(make_debug("fc1", 0) + make_debug("fc2", 6)) + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + layernorm_mlp.""" assert ( @@ -1691,14 +1803,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8FwdTensors.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: - # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale @@ -1706,7 +1818,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_INPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: - # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].with_amax_reduction = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2556987fed..e0954ebbb2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -28,11 +28,12 @@ clear_tensor_data, divide, init_method_constant, + requires_grad, + needs_quantized_gemm, non_tn_fp8_gemm_supported, assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, - requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -62,6 +63,8 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled __all__ = ["Linear"] @@ -84,8 +87,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -106,6 +110,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -144,7 +149,7 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" " current scaling" ) - + if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: @@ -196,9 +201,9 @@ def forward( nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - if not fp8: - weightmat = cast_if_needed(weight, activation_dtype) - else: + weightmat = weight + + if fp8 or debug: # Configure quantizer if weight_quantizer is not None: columnwise_usage = is_grad_enabled and inp.requires_grad @@ -208,7 +213,6 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -218,11 +222,14 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -343,12 +350,14 @@ def forward( ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer - ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad + ctx.debug = debug ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -528,7 +537,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -564,7 +573,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Update quantizer if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD @@ -678,6 +686,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -753,8 +762,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -775,6 +785,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -871,6 +884,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + name: Optional[str] = None, ) -> None: super().__init__() @@ -883,6 +897,10 @@ def __init__( self.apply_bias = bias and not return_bias self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1126,6 +1144,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() + if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() else: @@ -1161,13 +1183,28 @@ def forward( else: bias_tensor = None + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization @@ -1191,8 +1228,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1213,6 +1251,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: @@ -1224,8 +1263,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1243,8 +1283,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 22b86fbcc6..7fa12cc087 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -42,3 +42,27 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) + + +def get_all_tensor_types(): + """ + Get all tensor-like types that can be used in TE. + """ + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ) + + all_tensor_types = [ + torch.Tensor, + torch.nn.Parameter, + Float8Tensor, + Float8TensorBase, + MXFP8Tensor, + MXFP8TensorBase, + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ] + return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 2fea2c4f28..2b54e9ed79 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -27,12 +27,14 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - dtype = torch_to_transformer_engine_dtype[dtype] + te_dtype = torch_to_transformer_engine_dtype[dtype] # Make sure FP8 data is in expected format if tensor._data is not None: + if tensor._data.numel() == 0: + return torch.empty_like(tensor._data, dtype=dtype) # Cast from FP8 - return tex.dequantize(tensor, dtype) + return tex.dequantize(tensor, te_dtype) raise NotImplementedError("Casting back from the transpose not implemented yet!") diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 019aca9f60..aa433e58bc 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -37,7 +37,8 @@ def prepare_for_saving( def restore_from_saved( tensors: list[Optional[Any]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], -) -> list[Optional[Any]]: + return_saved_tensors: bool = False, +) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]: """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] for tensor in tensors: @@ -47,6 +48,9 @@ def restore_from_saved( else: saved_tensors = tensor.restore_from_saved(saved_tensors) tensor_objects.append(tensor) + + if return_saved_tensors: + return tensor_objects, saved_tensors return tensor_objects @@ -113,7 +117,11 @@ def update_quantized( """Quantize tensor in-place""" def quantize( - self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override ) -> QuantizedTensor: """Quantize tensor""" if out is not None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d829275777..ef7c4c8ab2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -11,6 +11,7 @@ import torch from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention import ( MultiheadAttention, ) @@ -33,6 +34,7 @@ dist_group_type, ) from transformer_engine.pytorch.distributed import get_distributed_world_size +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -277,6 +281,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -336,6 +341,8 @@ def __init__( self.attn_input_format = attn_input_format + self.name = name + attention_args = ( hidden_size, num_attention_heads, @@ -376,6 +383,7 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + name=name + ".self_attention" if name is not None else None, ) if layer_type == "decoder": @@ -389,6 +397,7 @@ def __init__( return_bias=True, normalization=normalization, device=device, + name=name + ".inter_attention" if name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -423,6 +432,7 @@ def __init__( activation=activation, normalization=normalization, device=device, + name=name + ".layernorm_mlp" if name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -679,6 +689,9 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 603c1d5de4..8450460c46 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,6 +11,7 @@ import torch import transformer_engine.pytorch.cpp_extensions as ext +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from .tensor.quantized_tensor import QuantizedTensor @@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple): return ((value + multiple - 1) // multiple) * multiple +def needs_quantized_gemm(obj, rowwise=True): + """Used to check if obj will need quantized gemm or normal gemm.""" + if isinstance(obj, DebugQuantizedTensor): + return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck + torch.Tensor, + torch.nn.Parameter, + ] + return type(obj) not in [ + torch.Tensor, + torch.nn.Parameter, + ] # pylint: disable=unidiomatic-typecheck + + @functools.lru_cache(maxsize=None) def _nvtx_enabled() -> bool: """Check if NVTX range profiling is enabled""" From cd035093572f63433f8abe1c1cabc7b3a9a14139 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 16 Apr 2025 06:21:44 -0700 Subject: [PATCH 37/53] add unittest for distributed interface apply Dener's suggestion Signed-off-by: Hongbin Liu --- tests/pytorch/distributed/run_numerics.py | 17 ++++++++++++ transformer_engine/pytorch/module/_common.py | 4 +-- transformer_engine/pytorch/module/base.py | 27 ++++++++++--------- .../pytorch/module/layernorm_linear.py | 12 +-------- .../pytorch/module/layernorm_mlp.py | 5 ++-- transformer_engine/pytorch/module/linear.py | 12 +-------- 6 files changed, 39 insertions(+), 38 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index b423bce53d..9252c3ce9b 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -299,6 +299,9 @@ def _loss_backward(output_single_node, output_distributed): LOSS_FN(output_single_node, target).backward() LOSS_FN(output_distributed, target).backward() +def _loss_backward_dw(model_single_node, model_distributed): + model_single_node.backward_dw() + model_distributed.backward_dw() def _alloc_main_grad(model_single_node, model_distributed): for model in [model_single_node, model_distributed]: @@ -473,6 +476,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + # Compute delayed weight gradient + if "split_bw" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -494,6 +501,7 @@ def test_linear(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"params_dtype": torch.float16}, + {"split_bw": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column", "row"]: @@ -645,6 +653,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + # Compute delayed weight gradient + if "split_bw" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -667,6 +679,7 @@ def test_layernorm_linear(): {"params_dtype": torch.float16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, + {"split_bw": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column"]: @@ -746,6 +759,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) + if "split_bw" in kwargs: + _loss_backward_dw(model_single_node, model_distributed) + # Validate outputs and gradients _check_outputs(output_single_node, output_distributed) @@ -771,6 +787,7 @@ def test_layernorm_mlp(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"return_layernorm_output": True}, + {"split_bw": True}, ] for kwargs in kwargs_list: diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index bbe22370b2..85e44b8de8 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -279,8 +279,8 @@ def pop(self): return func(*tensor_list), tensor_list if torch.distributed.is_initialized(): rank = torch.distributed.get_rank() - raise Exception(f"Pop empty queue. rank {rank}") - raise Exception("Pop empty queue. No distributed environment detected.") + raise RuntimeError(f"Pop empty queue. rank {rank}") + raise RuntimeError("Pop empty queue. No distributed environment detected.") def assert_empty(self): """ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 50494bb428..580f7f6a06 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1110,15 +1110,18 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() - if not self.fuse_wgrad_accumulation: - unfused_weights = [getattr(self, name) for name in self.weight_names] - weight_tensor = noop_cat(unfused_weights) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - if bias_tensor.grad is None: - bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) - del grad_bias_ - del wgrad + if self.wgrad_store is None or not self.wgrad_store.split_bw(): + return + with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): + (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + unfused_weights = [getattr(self, name) for name in self.weight_names] + weight_tensor = noop_cat(unfused_weights) + if weight_tensor.grad is None: + weight_tensor.grad = wgrad.to(weight_tensor.dtype) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = grad_bias_.to(bias_tensor.dtype) + del grad_bias_ + del wgrad diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index da7007f0dc..4bec9d91aa 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -781,7 +781,7 @@ def backward( dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) # Don't return grad bias if not needed - if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw()): + if not ctx.use_bias: grad_bias = None # Synchronize tensor parallel communication @@ -1492,13 +1492,3 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - - def backward_dw(self): - """ - Execute the delayed weight gradient computation. - This method is called after the main backward pass to compute weight gradients. - """ - if self.wgrad_store is None or not self.wgrad_store.split_bw(): - return - with torch.cuda.nvtx.range("_LayerNormLinear_wgrad"): - super().backward_dw() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6e9e6181c5..8f99fa4a22 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -804,7 +804,8 @@ def backward( if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) fc2_wgrad = None - fc2_bias_grad = None + # if fc2_bias is not None and fc2_bias_grad is None: + # fc2_bias_grad = None else: fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( act_out, @@ -1751,7 +1752,7 @@ def backward_dw(self): if self.fc2_bias.grad is None: if ( self.fp8 - and self.fp8_recipe.float8_block_scaling() + and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling() and self.apply_bias and not self.gemm_bias_unfused_add ): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5fc9217487..5d84dd09a3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -700,7 +700,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) # Don't return grad bias if not needed - if not ctx.use_bias or (ctx.wgrad_store is not None and ctx.wgrad_store.split_bw()): + if not ctx.use_bias: grad_bias = None # Make sure all tensor-parallel communication is finished @@ -1298,13 +1298,3 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - - def backward_dw(self): - """ - Execute the delayed weight gradient computation. - This method is called after the main backward pass to compute weight gradients. - """ - if self.wgrad_store is None or not self.wgrad_store.split_bw(): - return - with torch.cuda.nvtx.range("_Linear_wgrad"): - super().backward_dw() From 92b80ac1306aa42695d9cf5d8a7a04a5b232d193 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 13:41:17 +0000 Subject: [PATCH 38/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/run_numerics.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 9252c3ce9b..48e831acee 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -299,10 +299,12 @@ def _loss_backward(output_single_node, output_distributed): LOSS_FN(output_single_node, target).backward() LOSS_FN(output_distributed, target).backward() + def _loss_backward_dw(model_single_node, model_distributed): model_single_node.backward_dw() model_distributed.backward_dw() + def _alloc_main_grad(model_single_node, model_distributed): for model in [model_single_node, model_distributed]: for param in model.parameters(): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1d1239a9eb..410b2e3776 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -850,7 +850,7 @@ def backward( ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) fc2_wgrad = None # if fc2_bias is not None and fc2_bias_grad is None: - # fc2_bias_grad = None + # fc2_bias_grad = None else: fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( act_out, From 8ffbbabd11d7802afb84dcebc1e79ac62e53136d Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Wed, 16 Apr 2025 17:37:26 -0500 Subject: [PATCH 39/53] README.md - Installation section (#1689) * Update README.rst - Installation Update installation section with comprehensive guidelines - Add detailed system requirements - Include Conda installation method (experimental) - Document environment variables for customizing build process - Update FlashAttention support to cover both version 2 and 3 - Add troubleshooting section with solutions for common installation issues Signed-off-by: Santosh Bhavani * Update README.rst - Installation removed conda section Signed-off-by: Santosh Bhavani * Update README.rst - Installation added all gpu archs that support FP8 Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update installation.rst Signed-off-by: Kirthi Shankar Sivamani * Fix docs and adding troubleshooting Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Santosh Bhavani Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Przemyslaw Tredak Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 145 +++++++++++++++++++++++++++++++++--------- docs/installation.rst | 20 +++--- 2 files changed, 128 insertions(+), 37 deletions(-) diff --git a/README.rst b/README.rst index c4fde5bd11..3313a2625b 100644 --- a/README.rst +++ b/README.rst @@ -145,18 +145,30 @@ Flax Installation ============ -.. installation -Pre-requisites +System Requirements ^^^^^^^^^^^^^^^^^^^^ -* Linux x86_64 -* CUDA 12.1+ (CUDA 12.8+ for Blackwell) -* NVIDIA Driver supporting CUDA 12.1 or later -* cuDNN 9.3 or later -Docker -^^^^^^^^^^^^^^^^^^^^ +* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere + +* **OS:** Linux (official), WSL2 (limited support) + +* **Software:** + + * CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers + * cuDNN: 9.3+ + * Compiler: GCC 9+ or Clang 10+ with C++17 support + * Python: 3.12 recommended + +* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+ +* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell) + +Installation Methods +^^^^^^^^^^^^^^^^^^^ + +Docker (Recommended) +^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on `NVIDIA GPU Cloud (NGC) Catalog `_. For example to use the NGC PyTorch container interactively, @@ -167,41 +179,116 @@ For example to use the NGC PyTorch container interactively, Where 25.01 (corresponding to January 2025 release) is the container version. -pip -^^^^^^^^^^^^^^^^^^^^ -To install the latest stable version of Transformer Engine, +**Benefits of using NGC containers:** + +* All dependencies pre-installed with compatible versions and optimized configurations +* NGC PyTorch 23.08+ containers include FlashAttention-2 + +pip Installation +^^^^^^^^^^^^^^^^^^^ + +**Prerequisites for pip installation:** + +* A compatible C++ compiler +* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed + +To install the latest stable version with pip: .. code-block:: bash - pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable + # For PyTorch integration + pip install --no-build-isolation transformer_engine[pytorch] + + # For JAX integration + pip install --no-build-isolation transformer_engine[jax] + + # For both frameworks + pip install --no-build-isolation transformer_engine[pytorch,jax] + +Alternatively, install directly from the GitHub repository: + +.. code-block:: bash -This will automatically detect if any supported deep learning frameworks are installed and build -Transformer Engine support for them. To explicitly specify frameworks, set the environment variable -NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). + pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -Alternatively, the package can be directly installed from -`Transformer Engine's PyPI `_, e.g. +When installing from GitHub, you can explicitly specify frameworks using the environment variable: .. code-block:: bash - pip3 install transformer_engine[pytorch] + NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + +Source Installation +^^^^^^^^^^^^^^^^^^^ + +`See the installation guide `_ -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be -explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). -Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX -and PyTorch extensions. +Environment Variables +^^^^^^^^^^^^^^^^^^^ +These environment variables can be set before installation to customize the build process: -From source -^^^^^^^^^^^ -`See the installation guide `_. +* **CUDA_PATH**: Path to CUDA installation +* **CUDNN_PATH**: Path to cuDNN installation +* **CXX**: Path to C++ compiler +* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``) +* **MAX_JOBS**: Limit number of parallel build jobs (default varies by system) +* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job -Compiling with FlashAttention-2 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance. +Compiling with FlashAttention +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment. + +You can verify which FlashAttention version is being used by setting these environment variables: + +.. code-block:: bash + + NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. -Note that NGC PyTorch 23.08+ containers include FlashAttention-2. +.. troubleshooting-begin-marker-do-not-remove +Troubleshooting +^^^^^^^^^^^^^^^^^^^ + +**Common Issues and Solutions:** + +1. **ABI Compatibility Issues:** + + * **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine + * **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with matching ABI. + * **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers. + +2. **Missing Headers or Libraries:** + + * **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.) + * **Solution:** Install missing development packages or set environment variables to point to correct locations: + + .. code-block:: bash + + export CUDA_PATH=/path/to/cuda + export CUDNN_PATH=/path/to/cudnn + + * If CMake can't find a C++ compiler, set the ``CXX`` environment variable. + * Ensure all paths are correctly set before installation. + +3. **Build Resource Issues:** + + * **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors + * **Solution:** Limit parallel builds: + + .. code-block:: bash + + MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ... + +4. **Verbose Build Logging:** + + * For detailed build logs to help diagnose issues: + + .. code-block:: bash + + cd transformer_engine + pip install -v -v -v --no-build-isolation . + +.. troubleshooting-end-marker-do-not-remove Breaking Changes ================ diff --git a/docs/installation.rst b/docs/installation.rst index 10046d6306..d0d6cf96d2 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI Date: Wed, 16 Apr 2025 20:26:54 -0700 Subject: [PATCH 40/53] minor fix Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2183152c4b..b992a72e71 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -420,6 +420,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, *wgrad_list, *grad_biases, ) From 61312d6a84134219749559d51d60e223b8de45d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 17 Apr 2025 05:36:11 +0200 Subject: [PATCH 41/53] [PyTorch] Deprecate the weight offloading (#1678) * drop Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 18 ++--- transformer_engine/pytorch/cpu_offload.py | 73 +++++++++++-------- .../pytorch/module/layernorm_linear.py | 12 +-- .../pytorch/module/layernorm_mlp.py | 22 +----- transformer_engine/pytorch/module/linear.py | 9 +-- 5 files changed, 58 insertions(+), 76 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 194fed3adf..3db13593f5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -81,7 +81,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb -from .cpu_offload import set_offloading_param +from .cpu_offload import mark_activation_offload # Setup Attention Logging @@ -4323,10 +4323,9 @@ def forward( from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] - for tensor in tensor_list: - if tensor is not None: - set_offloading_param(tensor, "activation_offloading", True) + mark_activation_offload( + query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv + ) with self.attention_dropout_ctx(): # | API | use cases @@ -4729,13 +4728,8 @@ def forward( tensor_list = [q, k, v, out_save] qkv_layout = "sbhd_sbhd_sbhd" - for tensor in tensor_list: - if tensor is not None: - set_offloading_param(tensor, "activation_offloading", True) - - for tensor in aux_ctx_tensors: - if tensor is not None: - set_offloading_param(tensor, "activation_offloading", True) + mark_activation_offload(*tensor_list) + mark_activation_offload(*aux_ctx_tensors) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 93df512ac6..814e699557 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,18 +16,22 @@ CPUOffloadEnabled = False -def set_offloading_param(tensor, param_name, value): +def mark_activation_offload(*tensors): """Set the type of the offloading needed for a tensor.""" - assert param_name in ["weight_offloading", "activation_offloading"] - if tensor is None: - return - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: - setattr(tensor, param_name, value) - else: - data_tensors = tensor.get_data_tensors() - for tensor in data_tensors: - if tensor is not None: - setattr(tensor, param_name, value) + for tensor in tensors: + if tensor is None: + continue + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + tensor.activation_offloading = True + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + tensor.activation_offloading = True + # This is a hack to force clear the tensor after it is offloaded. + # It is needed, because .*TensorBase classes are saved in the ctx, + # and they contain the reference to their data tensors. + tensor.needs_force_clear = True def is_cpu_offload_enabled() -> bool: @@ -459,8 +463,15 @@ def synchronize_on_group_commit_forward(self, current_group): torch.cuda.current_stream().wait_stream(self.d2h_stream) # Time to free the activation memory after usage - for tensor_tag, _ in self.tensor_tag_to_buf.items(): + for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): if tensor_tag[0] == self.offloaded_group_count: + if hasattr(tensor_buf, "needs_force_clear"): + # Need to clear activation tensor - sometimes references persist in the code. + # This is the case for example with the Float8TensorBase class, + # which is saved directly inside the ctx while its internal tensors are + # saved inside save_for_backward. + tensor_buf.data = torch.Tensor() + # Release the pointer to the tensor self.tensor_tag_to_buf[tensor_tag] = None # Time to offload the next group @@ -538,7 +549,7 @@ def get_cpu_offload_context( num_layers: int = 1, model_layers: int = 1, offload_activations: bool = True, - offload_weights: bool = True, + offload_weights: bool = False, ): """ This function returns the CPU Offload context and the synchronizer function that needs to be @@ -570,28 +581,30 @@ def get_cpu_offload_context( """ - def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor, "activation_offloading") - - # This includes the Gradient Accumulation Buffer - def tensor_need_offloading_checker_weights(tensor): - return hasattr(tensor, "weight_offloading") - - def tensor_need_offloading_checker_all(tensor): - return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") - - if offload_activations and offload_weights: - tensor_need_offloading_checker = tensor_need_offloading_checker_all - elif offload_activations: - tensor_need_offloading_checker = tensor_need_offloading_checker_activations - elif offload_weights: - tensor_need_offloading_checker = tensor_need_offloading_checker_weights - else: + if not offload_weights and not offload_activations: raise ValueError( "CPU Offloading is enabled while it is not " "mentioned what to offload (weights/activations)" ) + if offload_weights: + import warnings + + warnings.warn( + "Offloading weights is deprecated. Using offload_weights=True does not have any" + " effect.", + DeprecationWarning, + ) + + # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. + if not offload_activations: + return nullcontext(), lambda x: x + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, num_model_group=model_layers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2cc6e770da..d9b135d571 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -63,7 +63,7 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpp_extensions import ( general_gemm, @@ -355,15 +355,7 @@ def forward( weightmat.update_usage(columnwise_usage=True) if cpu_offloading: - if fp8 and weightmat is not None: - set_offloading_param(weightmat, "weight_offloading", True) - set_offloading_param(ln_weight, "weight_offloading", True) - set_offloading_param(weight, "weight_offloading", True) - - set_offloading_param(inputmat, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(rsigma, "activation_offloading", True) - set_offloading_param(ln_out, "activation_offloading", True) + mark_activation_offload(inputmat, mu, rsigma, ln_out) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0fd051d781..dbdc9eca5e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -64,7 +64,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -473,23 +473,9 @@ def forward( clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) else: if cpu_offloading: - if fp8 and fc1_weight_final is not None: - set_offloading_param(fc1_weight_final, "weight_offloading", True) - if fp8 and fc2_weight_final is not None: - set_offloading_param(fc2_weight_final, "weight_offloading", True) - set_offloading_param(ln_weight, "weight_offloading", True) - set_offloading_param(fc1_weight, "weight_offloading", True) - set_offloading_param(fc2_weight, "weight_offloading", True) - set_offloading_param(fc1_bias, "weight_offloading", True) - - set_offloading_param(inputmat, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(rsigma, "activation_offloading", True) - set_offloading_param(mu, "activation_offloading", True) - set_offloading_param(ln_out, "activation_offloading", True) - set_offloading_param(fc1_out, "activation_offloading", True) - set_offloading_param(fc1_out_without_bias, "activation_offloading", True) - set_offloading_param(act_out, "activation_offloading", True) + mark_activation_offload( + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + ) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e0954ebbb2..ae2eafc44b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -62,7 +62,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.utils import any_feature_enabled @@ -307,11 +307,8 @@ def forward( if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) - if cpu_offloading: - set_offloading_param(weight, "weight_offloading", True) - set_offloading_param(weightmat, "weight_offloading", True) - if saved_inputmat is not None: - set_offloading_param(saved_inputmat, "activation_offloading", True) + if cpu_offloading and saved_inputmat is not None: + mark_activation_offload(saved_inputmat) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights From a0cabb717c7f5d51ff8f0986ea38b49346d7c1de Mon Sep 17 00:00:00 2001 From: linxiddd Date: Thu, 17 Apr 2025 23:23:33 +0800 Subject: [PATCH 42/53] [QA] Add XML log generation for pytest results (#1661) * [QA] Add error handling - Standardize test failure handling using the unified 'test_fail' function and 'error_exit' function Signed-off-by: Linxi Ding * Add XML log generation for pytest results - Add `--junitxml` option to pytest command to generate JUnit XML format logs Signed-off-by: Linxi Ding * Add $XML_LOG_DIR Signed-off-by: Linxi Ding * mkdir Signed-off-by: Linxi Ding * Update qa/L0_pytorch_unittest/test.sh Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Linxi Ding Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_jax_distributed_unittest/test.sh | 8 +++-- qa/L0_jax_unittest/test.sh | 11 +++--- qa/L0_pytorch_unittest/test.sh | 42 +++++++++++----------- qa/L1_jax_distributed_unittest/test.sh | 4 ++- qa/L1_pytorch_distributed_unittest/test.sh | 16 +++++---- qa/L1_pytorch_thunder_integration/test.sh | 4 ++- qa/L2_jax_unittest/test.sh | 10 +++--- qa/L3_pytorch_FA_versions_test/test.sh | 4 ++- 8 files changed, 58 insertions(+), 41 deletions(-) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 377a57e909..92434c28ea 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -17,16 +17,18 @@ RET=0 FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" wait -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 4cde60b5a9..6ffc5945a2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -20,20 +20,23 @@ FAILED_CASES="" pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" + : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "mnist" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index ffdc088a49..8e37a83dea 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -19,29 +19,31 @@ FAILED_CASES="" set -x : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -NVTE_FLASH_ATTN=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 96c5949a99..5deb77af91 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -5,5 +5,7 @@ set -xe : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 36d491ecd3..03997489e8 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -17,16 +17,18 @@ RET=0 FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential -python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index 1737ca9ba1..edf3f2eb84 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -5,9 +5,11 @@ set -x : ${THUNDER_PATH:=/opt/pytorch/lightning-thunder} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. Lightning tests are skipped on systems diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index 59212e38e1..7611575412 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -21,20 +21,22 @@ FAILED_CASES="" pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax/*not_distributed_*" # Test without custom calls -NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" +NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "mnist" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 3e83ef7f52..a42ec035e8 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -5,6 +5,8 @@ set -e : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 @@ -37,6 +39,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py done From 61f1bf6ffbc2939863295a7ba67f01d6619cc44d Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:47:07 -0700 Subject: [PATCH 43/53] Support computing zero-centered gamma in compute dtype for CuDNN (#1690) * Add a flag to support computing zero-centered gamma in weight dtype or compute dtype for CuDNN Signed-off-by: Jeremy Berchtold * Address comments Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- tests/cpp/operator/test_normalization.cu | 179 +++--------------- tests/cpp/operator/test_normalization.h | 178 +++++++++++++++++ .../cpp/operator/test_normalization_mxfp8.cu | 97 +++------- .../transformer_engine/normalization.h | 10 + .../common/normalization/common.cpp | 25 ++- .../common/normalization/layernorm/ln_api.cpp | 23 ++- .../normalization/rmsnorm/rmsnorm_api.cpp | 13 +- .../jax/cpp_extensions/normalization.py | 29 +++ 8 files changed, 313 insertions(+), 241 deletions(-) create mode 100644 tests/cpp/operator/test_normalization.h diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 0004c2ce74..a0ca938fbf 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -18,159 +18,16 @@ #include #include #include "../test_common.h" +#include "test_normalization.h" using namespace transformer_engine; using namespace test; namespace { -enum NormType { - LayerNorm, - RMSNorm -}; - -std::map normToString = { - {NormType::LayerNorm, "LayerNorm"}, - {NormType::RMSNorm, "RmsNorm"} -}; - -template -void compute_ref_stats(NormType norm_type, - const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon){ - using compute_t = float; - compute_t current, m; - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - sum += static_cast(data[i * H + j]); - } - if (norm_type == LayerNorm){ - mu[i] = sum / H; - m = mu[i]; - } else { m = 0;} - - compute_t sum_sq = 0; - for (size_t j = 0; j < H; ++j) { - current = static_cast(data[i * H + j]); - sum_sq += (current - m) * (current - m); - } - rsigma[i] = rsqrtf((sum_sq / H) + epsilon); - } -} - -// For now, cudnn does static_cast(gamma + static_cast(1.0)) -// This will be changed in the future release -template -inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){ - - using compute_t = float; - if constexpr (std::is_same_v || std::is_same_v){ - compute_t g = static_cast(gamma); - if (zero_centered_gamma) { - g += static_cast(1.f); - } - return g; - } else { - if (use_cudnn){ - compute_t g = static_cast(0.f); - InputType gi = gamma; - if (zero_centered_gamma) { - gi = gi + static_cast(1.f); - } - g = static_cast(gi); - return g; - } else { - compute_t g = static_cast(gamma); - if (zero_centered_gamma) { - g += static_cast(1.f); - } - return g; - } - } -} - -template -void compute_ref_output(NormType norm_type, - const InputType *data, const InputType *gamma, const InputType *beta, - OutputType* output, - const float *mu, const float *rsigma, - const size_t N, const size_t H, - float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - - compute_t tmp; - if (norm_type == LayerNorm) { - tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - } else { // RMSNorm - tmp = current * rsigma[i] * g; - } - - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - - -template -void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, - const float *mu, const float *rsigma, - const InputType *gamma, - InputType *data_grad, - InputType *gamma_grad, InputType *beta_grad, - const size_t N, const size_t H, - const bool zero_centered_gamma, const bool use_cudnn) { - using compute_t = float; - std::vector dgamma(H, 0.f); - std::vector dbeta(H, 0.f); - - for (size_t i = 0 ; i < N; ++i) { - // Reductions - auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; - compute_t mdy = 0, mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - local_mu) * rsigma[i]; - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - if (norm_type == LayerNorm) { - dbeta[j] += dz; - mdy += dy; - } - mdyy += dy * y; - } - mdy /= H; - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - local_mu) * rsigma[i]; - compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); - if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); -} - template void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, - NormType norm_type, bool use_cudnn) { + NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) { if (sizeof(InputType) < sizeof(OutputType)) { GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; return; @@ -219,9 +76,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); + if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) { + // Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation + GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend"; + } + if (use_cudnn){ nvte_enable_cudnn_norm_fwd(true); nvte_enable_cudnn_norm_bwd(true); + + + // Zero-centered gamma in weight dtype only supported by CuDNN backend currently + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(true); + } else { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } } // Forward kernel @@ -269,6 +139,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if (use_cudnn){ nvte_enable_cudnn_norm_fwd(false); nvte_enable_cudnn_norm_bwd(false); + + // Zero-centered gamma in weight dtype only supported by CuDNN backend currently + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } } // Reference implementations @@ -289,14 +164,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, &ref_amax, ref_scale, zero_centered_gamma, - use_cudnn); + use_cudnn, + zero_centered_gamma_in_weight_dtype); compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), N, H, zero_centered_gamma, - use_cudnn); + use_cudnn, + zero_centered_gamma_in_weight_dtype); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -341,6 +218,7 @@ NormType, transformer_engine::DType, transformer_engine::DType, std::pair, + bool, bool>> {}; TEST_P(NormTestSuite, TestNorm) { @@ -353,10 +231,11 @@ TEST_P(NormTestSuite, TestNorm) { const DType output_type = std::get<3>(GetParam()); const auto size = std::get<4>(GetParam()); const bool zero_centered_gamma = std::get<5>(GetParam()); + const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); + performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype); ); ); } @@ -370,6 +249,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), + ::testing::Values(false, true), ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn"; @@ -380,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<3>(info.param)) + "X" + std::to_string(std::get<4>(info.param).first) + "X" + std::to_string(std::get<4>(info.param).second) + "X" + - std::to_string(std::get<5>(info.param)); + std::to_string(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)); return name; }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h new file mode 100644 index 0000000000..368ffa66c9 --- /dev/null +++ b/tests/cpp/operator/test_normalization.h @@ -0,0 +1,178 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + + #pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include "../test_common.h" + +namespace test { +namespace { + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RmsNorm"} +}; + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + compute_t current, m; + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +template +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + + using compute_t = float; + + // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently + // Remove the use_cudnn check here when it is supported by both backends. + const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; + + if constexpr (std::is_same_v || std::is_same_v){ + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } else { + if (zero_centered_gamma_in_weight_dtype){ + compute_t g = static_cast(0.f); + InputType gi = gamma; + if (zero_centered_gamma) { + gi = gi + static_cast(1.f); + } + g = static_cast(gi); + return g; + } else { + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + OutputType* output, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = static_cast(tmp * scale); + current_max = fmaxf(current_max, fabsf(tmp)); + } + } + + if (amax) { + *amax = current_max; + } +} + + +template +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, + const float *mu, const float *rsigma, + const InputType *gamma, + InputType *data_grad, + InputType *gamma_grad, InputType *beta_grad, + const size_t N, const size_t H, + const bool zero_centered_gamma, const bool use_cudnn, + const bool cudnn_zero_centered_gamma_in_weight_dtype) { + using compute_t = float; + std::vector dgamma(H, 0.f); + std::vector dbeta(H, 0.f); + + for (size_t i = 0 ; i < N; ++i) { + // Reductions + auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; + compute_t mdy = 0, mdyy = 0; + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + dgamma[j] += y * dz; + if (norm_type == LayerNorm) { + dbeta[j] += dz; + mdy += dy; + } + mdyy += dy * y; + } + mdy /= H; + mdyy /= H; + + // Input grads + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + data_grad[i * H + j] = static_cast(dx); + } + } + + // Weight grads + for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); + if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); +} + +} // namespace +} // namespace test diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index 191c62835b..4d0cf86034 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -19,6 +19,7 @@ #include #include #include "../test_common.h" +#include "test_normalization.h" using namespace transformer_engine; using namespace test; @@ -27,16 +28,6 @@ namespace { using fp8e8m0 = byte; -enum NormType { - LayerNorm, - RMSNorm -}; - -std::map normToString = { - {NormType::LayerNorm, "LayerNorm"}, - {NormType::RMSNorm, "RMSNorm"} -}; - template void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ @@ -110,65 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training) 32, 1); } -template -void compute_ref_stats(NormType norm_type, - const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon){ - using compute_t = float; - - #pragma omp parallel for proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - sum += static_cast(data[i * H + j]); - } - compute_t m; - if (norm_type == LayerNorm){ - mu[i] = sum / H; - m = mu[i]; - } else { m = 0;} - - compute_t sum_sq = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum_sq += (current - m) * (current - m); - } - rsigma[i] = rsqrtf((sum_sq / H) + epsilon); - } -} - -template -void compute_ref_output(NormType norm_type, - const InputType *data, const InputType *gamma, const InputType *beta, - const float *mu, const float *rsigma, - const size_t N, const size_t H, - OutputType* output, - const bool zero_centered_gamma){ - using compute_t = float; - - #pragma omp parallel for proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1.0; - } - - compute_t tmp; - if (norm_type == LayerNorm) { - tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - } else { // RMSNorm - tmp = current * rsigma[i] * g; - } - - output[i * H + j] = tmp; - } - } -} - template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training, const bool zero_centered_gamma_in_weight_dtype) { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); @@ -195,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, fillUniform(&gamma); fillUniform(&beta); + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(true); + } else { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } + // Forward kernel float epsilon = 1e-5; if (norm_type == NormType::LayerNorm){ @@ -220,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, 0); } + if (zero_centered_gamma_in_weight_dtype) { + nvte_enable_zero_centered_gamma_in_weight_dtype(false); + } + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); dequantize_2x(z, dequantized_output, is_training); @@ -246,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, compute_ref_output(norm_type, input.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), beta.rowwise_cpu_dptr(), + ref_output.get(), ref_mu_ptr, ref_rsigma_ptr, N, H, - ref_output.get(), - zero_centered_gamma); + nullptr, // amax + 1.f, // scale + zero_centered_gamma, + true, // CuDNN is the only MXFP8 backend currently + zero_centered_gamma_in_weight_dtype); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -298,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple, - bool, bool>> {}; + bool, bool, bool>> {}; TEST_P(MxNormTestSuite, TestMxNorm) { using namespace transformer_engine; @@ -310,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) { const auto size = std::get<3>(GetParam()); const bool zero_centered_gamma = std::get<4>(GetParam()); const bool is_training = std::get<5>(GetParam()); + const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training); + performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training, zero_centered_gamma_in_weight_dtype); ); ); } @@ -327,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), ::testing::Values(true, false), + ::testing::Values(true, false), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { std::string name = normToString.at(std::get<0>(info.param)) + "_" + @@ -335,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<3>(info.param).first) + "X" + std::to_string(std::get<3>(info.param).second) + "X" + std::to_string(std::get<4>(info.param)) + "out" + - std::to_string(int(std::get<5>(info.param)) + 1) + "x"; + std::to_string(int(std::get<5>(info.param)) + 1) + "x" + + std::to_string(std::get<6>(info.param)); return name; }); diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h index 9b0b80acc2..9c194e9da2 100644 --- a/transformer_engine/common/include/transformer_engine/normalization.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable); +/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma + * in weight dtype. If set to false, it will compute in compute dtype. + * + * Currently this only applies to the CuDNN backend. If CuDNN is not used, + * this setting has no effect. + * + * \param[in] bool Enable if True + */ +void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable); + enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; #ifdef __cplusplus diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index ddda78d951..89affc081c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -39,6 +39,8 @@ Compute always in FP32 namespace transformer_engine { namespace normalization { +bool& use_zero_centered_gamma_in_weight_dtype(); + cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { return training ? cudnn_frontend::NormFwdPhase_t::TRAINING : cudnn_frontend::NormFwdPhase_t::INFERENCE; @@ -207,9 +209,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor _ndim_scale_block = 1; } - _scalar_dptr = std::make_unique(typeToSize(wtype)); + const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype; + + _scalar_dptr = std::make_unique(typeToSize(gamma_dtype)); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); + gamma_dtype, cpp_dtype, + *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); _handle = cudnnExecutionPlanManager::Instance().GetHandle(); @@ -239,13 +244,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_name("one") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(wtype)) + .set_data_type(get_cudnn_fe_dtype(gamma_dtype)) .set_is_pass_by_value(true)); auto centered_options = fe::graph::Pointwise_attributes() .set_mode(fe::PointwiseMode_t::ADD) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); - _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); + _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(gamma_dtype)); } else { _gamma = _gamma_zero; } @@ -503,6 +508,13 @@ bool& _cudnn_norm_bwd_flag() { bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } +bool& _zero_centered_gamma_in_weight_dtype() { + static bool flag = transformer_engine::getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE"); + return flag; +} + +bool& use_zero_centered_gamma_in_weight_dtype() { return _zero_centered_gamma_in_weight_dtype(); } + } // namespace normalization } // namespace transformer_engine @@ -515,3 +527,8 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; } + +void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { + NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); + transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; +} diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index f6b6ae22c2..47b37b3482 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -31,19 +31,24 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); + NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape."); + NVTE_CHECK(gamma.data.dtype == beta.data.dtype, + "Gamma and Beta must have the same dtype. Gamma dtype: " + + to_string(gamma.data.dtype) + ", Beta dtype: " + to_string(beta.data.dtype)); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size."); - NVTE_CHECK(epsilon >= 0.f); + NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative."); - NVTE_CHECK(z->data.shape == x.data.shape); + NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x."); - NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(mu->data.dtype == DType::kFloat32); + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}, + "Mu must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor."); - NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}, + "RSigma must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); if (!workspace->data.shape.empty()) { CheckInputTensor(x, "x"); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index c56f9ef407..48cf1d819b 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -27,15 +27,16 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } - NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); - NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); - NVTE_CHECK(epsilon >= 0.f); + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1], "Gamma must have the same hidden size."); + NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative."); - NVTE_CHECK(z->data.shape == x.data.shape); + NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x."); - NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); - NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}, + "RSigma must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); if (!workspace->data.shape.empty()) { CheckInputTensor(x, "x"); diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3aec30f420..54360c2dcc 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -64,6 +64,27 @@ def get_backward_sm_margin(): return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) +@cache +def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool: + """Retrieves whether CuDNN norm fwd is enabled.""" + # MXFP8_1D_SCALING always uses CuDNN currently + return ( + int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1 + or scaling_mode == ScalingMode.MXFP8_1D_SCALING + ) + + +@cache +def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bool: + """Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma + in weight dtype as opposed to compute dtype.""" + if not is_norm_fwd_cudnn_enabled(scaling_mode): + # If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype + # Remove this when TE supports gamma += 1.0 in weight dtype + return False + return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1 + + class NormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward FP8 Primitive @@ -788,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) JAX native layernorm implementation """ x_ = jnp.asarray(x, jnp.float32) + if not is_norm_zero_centered_gamma_in_weight_dtype( + quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING + ): + gamma = gamma.astype(jnp.float32) mean = jnp.mean(x_, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) @@ -809,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): JAX native rmsnorm implementation """ x_ = jnp.asarray(x, jnp.float32) + if not is_norm_zero_centered_gamma_in_weight_dtype( + quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING + ): + gamma = gamma.astype(jnp.float32) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) normed_input = x_ * rsigma From e61ce77c95ab1db6753af926e734cc155e5c772b Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Thu, 17 Apr 2025 16:02:12 -0700 Subject: [PATCH 44/53] Allow NVTEShape to own data. (#1674) * Allow NVTEShape to own data. Signed-off-by: Keith Wyss * Convert repeated copy paths to nvte_make_shape calls. Signed-off-by: Keith Wyss * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Build fixes. Signed-off-by: Keith Wyss * MR feedback. Signed-off-by: Keith Wyss --------- Signed-off-by: Keith Wyss Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/cpp/test_common.cu | 9 ++-- tests/cpp/test_common.h | 2 +- transformer_engine/common/common.h | 25 +---------- .../transformer_engine/transformer_engine.h | 42 +++++++++++++++---- .../common/transformer_engine.cpp | 37 +++++++++------- .../pytorch/csrc/extensions/attention.cu | 10 +++-- 6 files changed, 70 insertions(+), 55 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 61d3075265..0977c512cb 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -112,8 +112,8 @@ struct scale_inv_meta { size_t type_size; }; -NVTEShape convertShape(const std::vector& shape) { - return {shape.data(), shape.size()}; +NVTEShape convertShape(const std::vector& s) { + return nvte_make_shape(s.data(), s.size()); } std::pair get_scales(const NVTEShape& shape, @@ -240,7 +240,7 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); - NVTEShape columnwise_shape{nullptr, 0}; + NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -257,8 +257,7 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { - columnwise_shape.data = columnwise_shape_vec.data(); - columnwise_shape.ndim = columnwise_shape_vec.size(); + columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); } tensor_ = TensorWrapper(scaling_mode); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index d5ecc6d0f5..5e01dacc0a 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -109,7 +109,7 @@ class Tensor { const bool rowwise = true, const bool columnwise = false, const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} Tensor() {} diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index a852bda410..daed7718ff 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -78,8 +78,8 @@ struct SimpleTensor { SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} operator NVTEBasicTensor() const { - const NVTEShape shape = {this->shape.data(), this->shape.size()}; - return {dptr, static_cast(dtype), shape}; + return {dptr, static_cast(dtype), + nvte_make_shape(this->shape.data(), this->shape.size())}; } int numel() const { @@ -99,11 +99,6 @@ struct Tensor { SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; - private: - // Used as an allocation for nvte_tensor_shape - // if the shape has to be inferred from columnwise data. - mutable std::vector rowwise_shape_cache; - public: NVTEScalingMode scaling_mode; @@ -194,22 +189,6 @@ struct Tensor { } } - const std::vector &rowwise_shape_ref() const { - auto shape_queried = shape(); - // This method is primarily designed for nvte_shape. - // An unfortunate consequence of unconditionally assigning - // values to rowwise_shape_cache without a check is that - // repeated calls to rowwise_shape_ref are likely to - // invalidate the data pointers from previous calls. - // If the shape has changed, then invalidating is necessary - // in at least some cases, but we want to keep the data - // valid otherwise. - if (rowwise_shape_cache != shape_queried) { - rowwise_shape_cache = std::move(shape_queried); - } - return rowwise_shape_cache; - } - /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d3ee446f83..2c3192f773 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -42,6 +42,8 @@ struct NVTEShape { const size_t *data; /*! \brief Number of dimensions. */ size_t ndim; + /*! \brief Copy of data. Num dims limited to permit fixed struct size.*/ + size_t owned_data[14]; }; /*! \struct NVTEBasicTensor @@ -134,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor); */ void *nvte_tensor_columnwise_data(const NVTETensor tensor); +/*! \brief Construct a shape from an array of dimension sizes. + * + * \param[data] Pointer to start of shape array. + * \param[data] Number of dimensions (must be <= 14) + * + * \return A shape. The shape will own its own copy of the data. + */ +NVTEShape nvte_make_shape(const size_t *data, size_t ndim); + /*! \brief Get a tensor's data shape. * * \param[in] tensor Tensor. @@ -417,8 +428,9 @@ class TensorWrapper { float *amax_dptr = nullptr, float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, - scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + : TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr, + scale_dptr, scale_inv_dptr, + nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()), scaling_mode) {} /*! \brief Constructs new empty TensorWrapper. @@ -534,7 +546,9 @@ class TensorWrapper { * \return Shape of this TensorWrapper. */ const NVTEShape shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_shape(tensor_); } @@ -543,7 +557,9 @@ class TensorWrapper { * \return Shape of this TensorWrapper. */ const NVTEShape columnwise_shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_columnwise_shape(tensor_); } @@ -656,7 +672,9 @@ class TensorWrapper { * \return scale_inv_shape of this TensorWrapper. */ const NVTEShape scale_inv_shape() const noexcept { - if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + if (tensor_ == nullptr) { + return nvte_make_shape(nullptr, 0); + } return nvte_tensor_scale_inv_shape(tensor_); } @@ -672,12 +690,20 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = {&defaultData, 1}; + static constexpr NVTEShape defaultShape = { + &defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; private: - NVTEShape convertShape(const NVTEShape &s) { return s; } + NVTEShape convertShape(const NVTEShape &s) { + NVTEShape ret = s; + // Move the ownership rather than pointing to the parent shape. + ret.data = ret.owned_data; + return ret; + } - NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 706e0bd0b5..68f6bcc322 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -211,6 +211,22 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { reinterpret_cast(tensor)->dtype()); } +NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { + NVTEShape ret; + if (ndim == 0) { + ret.data = nullptr; + ret.ndim = 0; + return ret; + } + NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), + "Too many dims for NVTEShape (requested: ", ndim, + ", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")"); + std::copy(data, data + ndim, ret.owned_data); + ret.data = ret.owned_data; + ret.ndim = ndim; + return ret; +} + NVTEShape nvte_tensor_shape(const NVTETensor tensor) { if (tensor == nullptr) { NVTE_ERROR("Invalid tensor"); @@ -218,12 +234,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); - const std::vector &rowwise_shape = t.rowwise_shape_ref(); + std::vector shape = t.shape(); - NVTEShape ret; - ret.data = rowwise_shape.data(); - ret.ndim = rowwise_shape.size(); - return ret; + return nvte_make_shape(shape.data(), shape.size()); } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { @@ -231,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTE_ERROR("Invalid tensor"); } const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - return ret; + return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size()); } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } @@ -302,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { } NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; + if (tensor == nullptr) { + return nvte_make_shape(nullptr, 0); + } const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - ret.data = t.scale_inv.shape.data(); - ret.ndim = t.scale_inv.shape.size(); - return ret; + return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size()); } void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index da82120f4a..593c05f98e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -3,9 +3,11 @@ * * See LICENSE for license information. ************************************************************************/ + #include "extensions.h" #include "kv_cache.cuh" #include "thd_utils.cuh" +#include "transformer_engine/transformer_engine.h" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -449,13 +451,13 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - auto temp_vec = std::vector(tmp.begin(), tmp.end()); - const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()}; + const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); + const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - temp_shape}; + nvte_make_shape(tmp.data(), tmp.size())}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } From 4e036c8c8d442ab0c6a2719babd32fcfd9d4413e Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 18 Apr 2025 07:02:47 +0800 Subject: [PATCH 45/53] [PyTorch] Move swizzle scaling factor to cpp (#1683) * move swizzle scaling factor to cpp Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/transformer_engine.cpp | 2 +- .../pytorch/cpp_extensions/gemm.py | 47 ------------------- transformer_engine/pytorch/csrc/extensions.h | 20 ++++---- .../pytorch/csrc/extensions/apply_rope.cpp | 4 +- .../pytorch/csrc/extensions/attention.cu | 14 +++--- .../pytorch/csrc/extensions/gemm.cpp | 14 ++++++ .../pytorch/csrc/extensions/swizzle.cpp | 12 +++-- transformer_engine/pytorch/csrc/util.h | 13 +++++ 8 files changed, 54 insertions(+), 72 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 68f6bcc322..9072e1d060 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -268,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { size_t nvte_tensor_element_size(const NVTETensor tensor) { if (tensor == nullptr) return sizeof(float); const auto &t = *reinterpret_cast(tensor); - return transformer_engine::typeToSize(t.data.dtype); + return transformer_engine::typeToSize(t.dtype()); } void *nvte_tensor_data(const NVTETensor tensor) { diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 62f029bed7..b970d0549d 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -12,7 +12,6 @@ from ..utils import get_sm_count from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -28,46 +27,6 @@ def _empty_tensor() -> torch.Tensor: return torch.Tensor().cuda() -def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str): - """Swizzle gemm inputs and return original scaling factor inverses.""" - if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase): - return None - - original_scale_inverses = ( - A._rowwise_scale_inv, - A._columnwise_scale_inv, - B._rowwise_scale_inv, - B._columnwise_scale_inv, - ) - - if layout[0] == "T": - A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv) - else: - A._columnwise_scale_inv = tex.columnwise_swizzle( - A._columnwise_data, A._columnwise_scale_inv - ) - - if layout[1] == "N": - B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv) - else: - B._columnwise_scale_inv = tex.columnwise_swizzle( - B._columnwise_data, B._columnwise_scale_inv - ) - - return original_scale_inverses - - -def reset_swizzled_inputs(A, B, scale_inverses): - """Reset the swizzled scale inverses after GEMM.""" - if scale_inverses is not None: - ( - A._rowwise_scale_inv, - A._columnwise_scale_inv, - B._rowwise_scale_inv, - B._columnwise_scale_inv, - ) = scale_inverses - - def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -149,9 +108,7 @@ def general_gemm( "bulk_overlap": bulk_overlap, } - original_scale_inverses = swizzle_inputs(A, B, layout) out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - reset_swizzled_inputs(A, B, original_scale_inverses) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -210,8 +167,6 @@ def general_grouped_gemm( for o in out ] # this should differ with respect to single output - # TODO: Move the swizzle to the C++ side. # pylint: disable=fixme - original_scale_inverses_list = [swizzle_inputs(A[i], B[i], layout) for i in range(num_gemms)] bias = tex.te_general_grouped_gemm( A, transa, @@ -231,7 +186,5 @@ def general_grouped_gemm( use_split_accumulator, sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))), ) - for i in range(num_gemms): - reset_swizzled_inputs(A[i], B[i], original_scale_inverses_list[i]) return out, bias, gelu_input diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 14609762fc..770517a051 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -50,11 +50,11 @@ std::vector fused_attn_fwd( NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, const c10::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, + const std::optional page_table_k, const std::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + const std::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -63,8 +63,8 @@ std::vector fused_attn_bwd( const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer); at::Tensor fa_prepare_fwd(at::Tensor qkvi); @@ -270,12 +270,12 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank); /*************************************************************************************************** @@ -396,8 +396,6 @@ void nvshmem_finalize(); * swizzle **************************************************************************************************/ -void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); - at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 424a988301..3414975b0e 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -8,7 +8,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank) { using namespace transformer_engine::pytorch; @@ -96,7 +96,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, - const c10::optional cu_seqlens, const int cp_size, + const std::optional cu_seqlens, const int cp_size, const int cp_rank) { using namespace transformer_engine::pytorch; TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 593c05f98e..37b6840f1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -92,11 +92,11 @@ std::vector fused_attn_fwd( NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, const c10::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, + const std::optional page_table_k, const std::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + const std::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; using namespace transformer_engine::pytorch; TensorWrapper te_Q, te_K, te_V, te_O, te_S; @@ -282,8 +282,8 @@ std::vector fused_attn_bwd( const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_q_padded, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer) { using namespace transformer_engine; using namespace transformer_engine::pytorch; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index ff61cd940c..5860d9ff2c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -17,6 +17,7 @@ #include "extensions.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" +#include "util.h" namespace { @@ -175,8 +176,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + // Optionally swizzle the scaling factors + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(swizzle_scaling_factors(B_tensor, !transb))); + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; @@ -313,6 +321,8 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_vector, te_workspace_vector; std::vector wrappers; std::vector D_vectors; + // Keep the swizzled scaling factor tensors alive during the GEMMs. + std::vector> swizzled_scale_inverses_list; auto none = py::none(); @@ -379,6 +389,10 @@ std::optional> te_general_grouped_gemm( continue; } + // Optionally swizzle the scaling factors + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); + swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); + auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index b127b5d75b..a16c43d2d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -6,14 +6,16 @@ #include "extensions.h" #include "transformer_engine/transformer_engine.h" +#include "util.h" -void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { +std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, + bool rowwise) { using namespace transformer_engine::pytorch; if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { - return; + return std::nullopt; } NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); @@ -48,9 +50,9 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); - output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } @@ -63,6 +65,8 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww } else { input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); } + + return swizzled_scale_inv; } at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index cbdf0833ed..a69e2cc24f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -7,6 +7,19 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#include + +#include + +#include "transformer_engine/transformer_engine.h" + bool non_tn_fp8_gemm_supported(); +/* Swizzle the scaling factor of the input tensor. + * + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + */ +std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, + bool trans); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From 39c0e709051536a52fc93525fce2b2557078b100 Mon Sep 17 00:00:00 2001 From: wdykas <73254672+wdykas@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:44:38 -0400 Subject: [PATCH 46/53] Re Do symmetric memory merge request (#1682) * re merge request Signed-off-by: Peter Dykas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add docstring Signed-off-by: Peter Dykas --------- Signed-off-by: Peter Dykas Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/distributed.py | 153 ++++++++++++++++++ .../pytorch/module/layernorm_linear.py | 24 ++- .../pytorch/module/layernorm_mlp.py | 26 ++- transformer_engine/pytorch/module/linear.py | 26 ++- 4 files changed, 224 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 0e11b2c102..fe77b69cad 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -31,6 +31,13 @@ from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..debug.pytorch.debug_quantization import DebugQuantizedTensor +try: + import torch.distributed._symmetric_memory as symm_mem + + HAS_TORCH_SYMMETRIC = True +except ImportError: + HAS_TORCH_SYMMETRIC = False + __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -1260,6 +1267,152 @@ def gather_along_first_dim( return out, handle +# Global cache to store symmetric memory tensors +symmetric_mem_cache = {} + + +def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None): + """ + Gets or creates a symmetric memory tensor with specified properties. + + Reuses cached tensors when available to avoid redundant creation and rendezvous operations. + + Note: This function always returns a 1D tensor. + + Parameters + ---------- + tensor_numel : int + Number of elements in the tensor. + tensor_dtype : torch.dtype + Data type of the tensor. + tensor_device : torch.device + Device on which to allocate the tensor. + tp_group : dist_group_type + Process group for rendezvous operation. + tag : Any, optional + Optional identifier to further distinguish tensors. + + Returns + ------- + torch.Tensor + A symmetric memory tensor with the specified properties. + """ + # Create a cache key based on tensor properties and group + cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag) + + # Check if we already have a symmetric memory tensor for this configuration + if cache_key not in symmetric_mem_cache: + # Create a new symmetric memory tensor if not in cache + msg = symm_mem.empty( + tensor_numel, + dtype=tensor_dtype, + device=tensor_device, + ) + # Perform the rendezvous once for this tensor + symm_mem.rendezvous(msg, group=tp_group) + # Store in cache + symmetric_mem_cache[cache_key] = msg + else: + # Reuse the existing symmetric memory tensor + msg = symmetric_mem_cache[cache_key] + + return msg + + +def symmetric_all_reduce( + inp: torch.Tensor, + tp_group: Optional[dist_group_type] = None, + async_op: bool = False, + all_reduce_type: str = "multimem_all_reduce", +): + """ + Performs an all-reduce operation across multiple processes using symmetric memory. + If the input tensor is already in the symmetric memory cache we can avoid copy + overheads by just directly using the input tensor for all reduce. Externally + created symmetric memory tensors not in the cache currently will not be able to + avoid the extra copies. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to be reduced. The operation is performed in-place. + + tp_group : Optional[dist_group_type], default=None + The process group over which to perform the all-reduce operation. + If None, the default process group is used. + + async_op : bool, default=False + Whether to perform the operation asynchronously. + Note: Currently only synchronous operations are supported for symmetric memory variants. + + all_reduce_type : str, default="multimem_all_reduce" + The type of all-reduce implementation to use. Options include: + - "nccl": Standard PyTorch distributed all-reduce + - "multimem_all_reduce": multimem symmetric all-reduce + - "two_shot": Two-shot symmetric all-reduce + - "one_shot": One-shot symmetric all-reduce + + Returns + ------- + Tuple[torch.Tensor, Optional[torch.distributed.Work]] + - The first element is the input tensor with the all-reduce result. + - The second element is the async work handle if async_op=True, + otherwise None. + """ + assert async_op is False, "Async symmetric ops no supported yet" + assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch" + + if get_distributed_world_size(tp_group) == 1: + return inp, None + + if all_reduce_type == "nccl": + # Standard all-reduce implementation + handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op) + return inp, handle + + all_reduce_impl = None + if all_reduce_type == "multimem_all_reduce": + all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_ + elif all_reduce_type == "two_shot": + all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_ + elif all_reduce_type == "one_shot": + all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce + else: + raise TypeError(f"All reduce type {all_reduce_type} is not supported.") + + group_name = tp_group.group_name + tensor_shape = inp.shape + tensor_numel = inp.numel() + tensor_dtype = inp.dtype + tensor_device = inp.device + + input_id = id(inp) + is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values()) + # Check if the input tensor is already in the symmetric memory cache. If it is we can avoid copy overheads. + if is_cached: + all_reduce_impl( + inp, + "sum", + group_name, + ) + else: + # Get symmetric memory tensor. Build or retrieve from cache. + msg = get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group) + + msg.copy_(inp.reshape(-1)) + + all_reduce_impl( + msg, + "sum", + group_name, + ) + + # Copy the result back to the input tensor + inp.copy_(msg.reshape(tensor_shape)) + + return inp, None + + def allreduce( inp: torch.Tensor, tp_group: Optional[dist_group_type] = None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d9b135d571..39158a7566 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, get_ub, @@ -41,6 +42,7 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, in_fp8_activation_recompute_phase, @@ -120,6 +122,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -445,7 +448,10 @@ def forward( if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: - out, _ = allreduce(out, tp_group) + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") # [*, in_features] -> [*, out_features] except first dimension changes for SP @@ -896,6 +902,7 @@ def backward( None, # debug None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type ) @@ -985,6 +992,11 @@ class LayerNormLinear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -1014,6 +1026,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + symmetric_ar_type: Optional[str] = None, name: str = None, ) -> None: super().__init__() @@ -1030,6 +1043,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.symmetric_ar_type = symmetric_ar_type self.name = name if TEDebugState.debug_enabled: @@ -1099,6 +1113,13 @@ def __init__( assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, dtype=params_dtype) @@ -1433,6 +1454,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, debug, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index dbdc9eca5e..ca863913fc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -16,6 +16,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, _ub_communicators, @@ -47,6 +48,7 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, use_reentrant_activation_recompute, @@ -191,6 +193,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -590,7 +593,12 @@ def forward( elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) elif set_parallel_mode and tensor_parallel: - fc2_out, _ = allreduce(fc2_out, tp_group) + if symmetric_ar_type is not None: + fc2_out, _ = symmetric_all_reduce( + fc2_out, tp_group, all_reduce_type=symmetric_ar_type + ) + else: + fc2_out, _ = allreduce(fc2_out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1]) @@ -1190,6 +1198,7 @@ def backward( None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type None, # debug ) @@ -1287,6 +1296,11 @@ class LayerNormMLP(TransformerEngineBaseModule): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propogation and activation recompute phase. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -1319,6 +1333,7 @@ def __init__( ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, + symmetric_ar_type: Optional[str] = None, ) -> None: super().__init__() @@ -1337,6 +1352,7 @@ def __init__( ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.symmetric_ar_type = symmetric_ar_type # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( @@ -1376,6 +1392,13 @@ def __init__( ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad ) + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1651,6 +1674,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, debug, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ae2eafc44b..5a993e14f9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version from .base import ( get_workspace, get_ub, @@ -39,6 +40,7 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, + symmetric_all_reduce, reduce_scatter_along_first_dim, gather_along_first_dim, is_fp8_activation_recompute_enabled, @@ -66,7 +68,6 @@ from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.utils import any_feature_enabled - __all__ = ["Linear"] @@ -110,6 +111,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + symmetric_ar_type: str, debug: Optional[bool] = False, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -387,7 +389,10 @@ def forward( if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: - out, _ = allreduce(out, tp_group) + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") out = out.view(-1, *inp_shape[1:-1], out_features) @@ -782,6 +787,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # symmetric_ar_type None, # debug ) @@ -855,7 +861,11 @@ class Linear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. - + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. """ def __init__( @@ -881,6 +891,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, ) -> None: super().__init__() @@ -894,6 +905,7 @@ def __init__( self.apply_bias = bias and not return_bias self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.symmetric_ar_type = symmetric_ar_type self.name = name if TEDebugState.debug_enabled: @@ -963,6 +975,13 @@ def __init__( assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." self.ub_name = ub_name + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1248,6 +1267,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, debug, ) out = linear_fn(*args) From 1a6a6d7b3a3604a5ef79fb8ad100a381794cef15 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 17 Apr 2025 21:47:03 -0400 Subject: [PATCH 47/53] [JAX] Deprecate Praxis layers (#1694) rm pax/praxis Signed-off-by: Phuong Nguyen Co-authored-by: Kirthi Shankar Sivamani --- qa/L2_jax_unittest/test.sh | 2 +- setup.py | 1 - transformer_engine/jax/praxis/__init__.py | 9 - transformer_engine/jax/praxis/module.py | 311 -------------- transformer_engine/jax/praxis/transformer.py | 408 ------------------- transformer_engine/jax/setup.py | 2 +- 6 files changed, 2 insertions(+), 731 deletions(-) delete mode 100644 transformer_engine/jax/praxis/__init__.py delete mode 100644 transformer_engine/jax/praxis/module.py delete mode 100644 transformer_engine/jax/praxis/transformer.py diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index 7611575412..07eb0fc8f1 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -24,7 +24,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" # Test without custom calls NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" diff --git a/setup.py b/setup.py index 6969ad76e7..97fb292c51 100644 --- a/setup.py +++ b/setup.py @@ -119,7 +119,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) - # test_reqs.extend(["numpy", "praxis"]) test_reqs.extend(["numpy"]) return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] diff --git a/transformer_engine/jax/praxis/__init__.py b/transformer_engine/jax/praxis/__init__.py deleted file mode 100644 index 5352f1f53b..0000000000 --- a/transformer_engine/jax/praxis/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Praxis related Modules""" -from .module import FusedSoftmax, LayerNorm -from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer -from .transformer import DotProductAttention, MultiHeadAttention -from .transformer import RelativePositionBiases, TransformerLayer -from ..flax.transformer import TransformerLayerType diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py deleted file mode 100644 index ce407f94fc..0000000000 --- a/transformer_engine/jax/praxis/module.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" -Praxis Modules -""" -from dataclasses import field -from functools import partial -from typing import Callable, Iterable, Sequence, Tuple, Union - -from praxis import pax_fiddle -from praxis.base_layer import init_var -from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection -from praxis.layers import flax_adapter -from praxis.pytypes import JTensor - -from ..fp8 import FP8Helper -from ..flax.module import DenseGeneral, LayerNormDenseGeneral -from ..flax.module import LayerNorm as flax_LayerNorm -from ..flax.module import LayerNormMLP as flax_LayerNormMLP -from ..flax.module import Softmax -from ..softmax import SoftmaxType - - -def _generate_ln_scale_init(scale_init): - if scale_init is not None: - return TransformerEngineBaseLayer.generate_params_init("scale", scale_init) - return scale_init - - -class TransformerEngineBaseLayer(BaseLayer): - """TransformerEngineBaseLayer""" - - logical_axes_rules: Tuple[Tuple, ...] = None - - @staticmethod - def generate_params_init(name: str, initializer: WeightInit): - """generate_params_init""" - - def kernel_init(key, shape, dtype): - wp = WeightHParams(shape=shape, init=initializer, dtype=dtype) - return init_var(wp, key, name) - - return kernel_init - - def create_layer(self, name, flax_module_cls): - """create_layer""" - - fp8_collection_map = { - FP8Helper.FP8_COLLECTION_NAME: [ - WeightHParamsCollection.SKIP_LP_REGULARIZATION, - WeightHParamsCollection.OVERWRITE_WITH_GRADIENT, - WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION, - ] - } - - flax_module_p = pax_fiddle.Config( - flax_adapter.FlaxModuleAdapter, - module_factory_method=flax_module_cls, - logical_axes_rules=self.logical_axes_rules, - var_collection_map=fp8_collection_map, - ici_mesh_shape=self.ici_mesh_shape, - dcn_mesh_shape=self.dcn_mesh_shape, - mesh_axis_names=self.mesh_axis_names, - ) - - self.create_child(name, flax_module_p.clone()) - - -class LayerNorm(TransformerEngineBaseLayer): - """LayerNorm""" - - epsilon: float = 1e-6 - layernorm_type: str = "layernorm" - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - ln_cls = partial( - flax_LayerNorm, - epsilon=self.epsilon, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init), - bias_axes=self.bias_axes, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("layer_norm", ln_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.layer_norm(x) - - -class FusedSoftmax(TransformerEngineBaseLayer): - """FusedSoftmax""" - - scale_factor: float = 1.0 - softmax_type: SoftmaxType = SoftmaxType.SCALED - - def setup(self) -> None: - """setup""" - super().setup() - - fused_softmax_cls = partial( - Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type - ) - - self.create_layer("fused_softmax", fused_softmax_cls) - - def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor: - """__call__""" - return self.fused_softmax(x, mask, bias) - - -class Linear(TransformerEngineBaseLayer): - """Linear""" - - out_features: int = 512 - kernel_axes: Tuple[str, ...] = () - use_bias: bool = True - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - dense_general_cls = partial( - DenseGeneral, - features=self.out_features, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes=self.kernel_axes, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes=self.bias_axes, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("linear", dense_general_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.linear(x) - - -class LayerNormLinear(TransformerEngineBaseLayer): - """LayerNormLinear""" - - out_features: int = 512 - enable_layernorm: bool = True - layernorm_type: str = "layernorm" - epsilon: float = 1e-6 - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=1.0) - ) - ln_bias_axes: Tuple[str, ...] = () - kernel_axes: Tuple[str, ...] = () - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - return_layernorm_output: bool = True - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - depth_scaling: float = None - - def setup(self) -> None: - """setup""" - super().setup() - - ln_dense_general_cls = partial( - LayerNormDenseGeneral, - features=self.out_features, - enable_layernorm=self.enable_layernorm, - layernorm_type=self.layernorm_type, - epsilon=self.epsilon, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init - ), - ln_bias_axes=self.ln_bias_axes, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes=self.kernel_axes, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes=self.bias_axes, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - return_layernorm_output=self.return_layernorm_output, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - depth_scaling=self.depth_scaling, - ) - - self.create_layer("ln_linear", ln_dense_general_cls) - - def __call__(self, x: JTensor) -> JTensor: - """__call__""" - return self.ln_linear(x) - - -class LayerNormMLP(TransformerEngineBaseLayer): - """LayerNormMLP""" - - intermediate_dim: int = 2048 - enable_layernorm: bool = True - layernorm_type: str = "layernorm" - epsilon: float = 1e-6 - zero_centered_gamma: bool = False - scale_init: WeightInit = None - scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=1.0) - ) - ln_bias_axes: Tuple[str, ...] = () - kernel_axes_1: Tuple[str, ...] = () - kernel_axes_2: Tuple[str, ...] = () - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - bias_axes_1: Tuple[str, ...] = () - bias_axes_2: Tuple[str, ...] = () - enable_low_rank_adaptation: bool = False - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",) - intermediate_dropout_rate: float = 0.1 - intermediate_hidden_dropout_dims: Sequence[int] = () - axis: Union[Iterable[int], int] = -1 - transpose_batch_sequence: bool = False - - def setup(self) -> None: - """setup""" - super().setup() - - ln_mlp_cls = partial( - flax_LayerNormMLP, - intermediate_dim=self.intermediate_dim, - enable_layernorm=self.enable_layernorm, - layernorm_type=self.layernorm_type, - epsilon=self.epsilon, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init - ), - ln_bias_axes=self.ln_bias_axes, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - kernel_axes_1=self.kernel_axes_1, - kernel_axes_2=self.kernel_axes_2, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - bias_axes_1=self.bias_axes_1, - bias_axes_2=self.bias_axes_2, - enable_low_rank_adaptation=self.enable_low_rank_adaptation, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - return_layernorm_output=self.return_layernorm_output, - activations=self.activations, - intermediate_dropout_rate=self.intermediate_dropout_rate, - intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims, - axis=self.axis, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - ) - - self.create_layer("ln_mlp", ln_mlp_cls) - - def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor: - """__call__""" - return self.ln_mlp(x, deterministic) diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py deleted file mode 100644 index f441834355..0000000000 --- a/transformer_engine/jax/praxis/transformer.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" -Praxis Modules related Transformer -""" -from dataclasses import field -from functools import partial -from typing import Optional, Sequence, Tuple -import warnings - -from praxis import pax_fiddle -from praxis.base_layer import WeightInit -from praxis.pytypes import JTensor - -from .module import TransformerEngineBaseLayer -from ..flax.transformer import TransformerLayerType -from ..flax.transformer import DotProductAttention as flax_DotProductAttention -from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention -from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases -from ..flax.transformer import TransformerLayer as flax_TransformerLayer -from ..attention import AttnBiasType, AttnMaskType - - -class RelativePositionBiases(TransformerEngineBaseLayer): - """RelativePositionBiases""" - - num_buckets: int = 32 - max_distance: int = 128 - num_attention_heads: int = 64 - embedding_init: WeightInit = None - embedding_axes: Tuple[str, ...] = () - - @staticmethod - def generate_embedding_init(init, num_attention_heads, num_buckets): - """generate_embedding_init""" - embedding_init = init - if embedding_init is None: - rb_stddev = (num_attention_heads * num_buckets) ** -0.5 - embedding_init = WeightInit.Gaussian(rb_stddev) - return embedding_init - - def setup(self) -> None: - """setup""" - super().setup() - - embedding_init = RelativePositionBiases.generate_embedding_init( - self.embedding_init, self.num_attention_heads, self.num_buckets - ) - - rpb_cls = partial( - flax_RelativePositionBiases, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - num_attention_heads=self.num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init - ), - embedding_axes=self.embedding_axes, - dtype=self.dtype, - ) - - self.create_layer("relative_position_bias", rpb_cls) - - def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor: - """__call__""" - return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional) - - -class DotProductAttention(TransformerEngineBaseLayer): - """DotProductAttention""" - - head_dim: int = 0 - num_attention_heads: int = 0 - num_gqa_groups: Optional[int] = None - attention_dropout: float = 0.0 - attn_mask_type: AttnMaskType = "causal" - attn_bias_type: AttnBiasType = None - dropout_rng_name: str = "dropout" - float32_logits: bool = False - qkv_layout: str = "bshd_bshd_bshd" - scale_factor: Optional[float] = None - transpose_batch_sequence: bool = True - window_size: Optional[Tuple[int, int]] = None - - def setup(self) -> None: - """setup""" - super().setup() - - assert self.head_dim > 0, f"{self.head_dim=}" - assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" - - dpa_cls = partial( - flax_DotProductAttention, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - attention_dropout=self.attention_dropout, - dtype=self.dtype, - dropout_rng_name=self.dropout_rng_name, - float32_logits=self.float32_logits, - qkv_layout=self.qkv_layout, - scale_factor=self.scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence, - window_size=self.window_size, - ) - - self.create_layer("dot_product_attention", dpa_cls) - - def __call__( - self, - query: JTensor, - key: JTensor, - value: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - deterministic: bool = False, - ) -> JTensor: - """__call__""" - return self.dot_product_attention( - query, key, value, mask, bias, deterministic=deterministic - ) - - -class MultiHeadAttention(TransformerEngineBaseLayer): - """MultiHeadAttention""" - - head_dim: int = 0 - num_attention_heads: int = 0 - num_gqa_groups: Optional[int] = None - attention_dropout: float = 0.0 - dropout_rng_name: str = "dropout" - input_layernorm: bool = True - layernorm_type: str = "layernorm" - layernorm_epsilon: float = 1e-6 - zero_centered_gamma: bool = False - return_layernorm_output: bool = False - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - attn_mask_type: str = "causal" - attn_bias_type: Optional[str] = None - enable_rotary_pos_emb: bool = False - rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = "consecutive" - low_rank_adaptation_scope: str = "none" - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - fuse_qkv_params: bool = True - transpose_batch_sequence: bool = True - enable_sequence_parallel: bool = False - scale_attn_logits: bool = False - scaled_query_init: bool = True - float32_logits: bool = False - window_size: Optional[Tuple[int, int]] = None - - # Deprecated parameters - num_heads: Optional[int] = None - dropout_rate: Optional[float] = None - output_layernorm: Optional[bool] = None - apply_residual_connection_post_layernorm: Optional[bool] = None - fuse_qkv: Optional[bool] = None - - def __post_init__(self): - # Deal with the deprecated parameters - if self.num_heads is not None: - self.num_attention_heads = self.num_heads - warnings.warn( - f"{__class__}.num_heads is deprecated. It will be removed recently. " - f"Please uses {__class__}.num_attention_heads as the new API.", - DeprecationWarning, - ) - if self.dropout_rate is not None: - self.attention_dropout = self.dropout_rate - warnings.warn( - f"{__class__}.dropout_rate is deprecated. It will be removed recently. " - f"Please use {__class__}.attention_dropout as the new API.", - DeprecationWarning, - ) - if self.apply_residual_connection_post_layernorm is not None: - warnings.warn( - f"{__class__}.apply_residual_connection_post_layernorm is deprecated. " - f"It will be removed recently, please use {__class__}.return_layernorm_output.", - DeprecationWarning, - ) - if self.fuse_qkv is not None: - warnings.warn( - f"{__class__}.fuse_qkv is deprecated. It will be removed recently. " - f"Please use {__class__}.fuse_qkv_params as the new API.", - DeprecationWarning, - ) - assert self.output_layernorm is None, ( - f"{__class__}.output_layernorm is deprecated. It will be removed recently. " - f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm." - ) - - if self.num_gqa_groups is None: - self.num_gqa_groups = self.num_heads - super().__post_init__() - - def setup(self) -> None: - """setup""" - super().setup() - - assert self.head_dim > 0, f"{self.head_dim=}" - assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" - - mha_cls = partial( - flax_MultiHeadAttention, - dtype=self.dtype, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attention_dropout=self.attention_dropout, - dropout_rng_name=self.dropout_rng_name, - input_layernorm=self.input_layernorm, - layernorm_type=self.layernorm_type, - layernorm_epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - return_layernorm_output=self.return_layernorm_output, - kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_windows=self.rotary_pos_emb_windows, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - low_rank_adaptation_scope=self.low_rank_adaptation_scope, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - fuse_qkv_params=self.fuse_qkv_params, - transpose_batch_sequence=self.transpose_batch_sequence, - enable_sequence_parallel=self.enable_sequence_parallel, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - float32_logits=self.float32_logits, - window_size=self.window_size, - ) - - self.create_layer("multi_head_attn", mha_cls) - - def __call__( - self, - inputs_q: JTensor, - inputs_kv: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - decode: bool = False, - deterministic: bool = False, - ) -> JTensor: - """__call__""" - return self.multi_head_attn( - inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic - ) - - -class TransformerLayer(TransformerEngineBaseLayer): - """TransformerLayer""" - - hidden_size: int = 512 - mlp_hidden_size: int = 2048 - num_attention_heads: int = 8 - num_gqa_groups: Optional[int] = None - layernorm_type: str = "layernorm" - layernorm_epsilon: float = 1e-6 - zero_centered_gamma: bool = False - hidden_dropout: float = 0.1 - hidden_dropout_dims: Sequence[int] = () - attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 - intermediate_dropout_dims: Sequence[int] = () - dropout_rng_name: str = "dropout" - mlp_activations: Sequence[str] = ("relu",) - use_bias: bool = False - bias_init: WeightInit = field( # pylint: disable=invalid-field-call - default_factory=partial(WeightInit.Constant, scale=0.0) - ) - apply_residual_connection_post_layernorm: bool = False - output_layernorm: bool = False - float32_attention_logits: bool = False - layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = "causal" - self_attn_bias_type: Optional[str] = None - enable_rotary_pos_emb: bool = False - rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = "consecutive" - low_rank_adaptation_scope: str = "none" - low_rank_adaptation_dim: int = 32 - low_rank_adaptation_alpha: float = None - enable_relative_embedding: bool = True - relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) - drop_path: float = 0.0 - fuse_qkv_params: bool = True - transpose_batch_sequence: bool = False - enable_sequence_parallel: bool = False - scale_attn_logits: bool = False - scaled_query_init: bool = True - window_size: Optional[Tuple[int, int]] = None - - def __post_init__(self): - if self.num_gqa_groups is None: - self.num_gqa_groups = self.num_attention_heads - super().__post_init__() - - def setup(self) -> None: - """setup""" - super().setup() - - relative_embedding_flax_module = None - if self.enable_relative_embedding and self.relative_embedding is not None: - assert self.relative_embedding.num_attention_heads == self.num_attention_heads, ( - "TransformerLayer.relative_embedding.num_attention_heads shoule be" - "the same as TransformerLayer.num_attention_heads." - ) - - embedding_init = RelativePositionBiases.generate_embedding_init( - self.relative_embedding.embedding_init, - self.relative_embedding.num_attention_heads, - self.relative_embedding.num_buckets, - ) - - relative_embedding_flax_module = flax_RelativePositionBiases( - num_buckets=self.relative_embedding.num_buckets, - max_distance=self.relative_embedding.max_distance, - num_attention_heads=self.relative_embedding.num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init - ), - embedding_axes=self.relative_embedding.embedding_axes, - dtype=self.relative_embedding.dtype, - ) - - transformerlayer_cls = partial( - flax_TransformerLayer, - dtype=self.dtype, - hidden_size=self.hidden_size, - mlp_hidden_size=self.mlp_hidden_size, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - layernorm_type=self.layernorm_type, - layernorm_epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - hidden_dropout=self.hidden_dropout, - hidden_dropout_dims=self.hidden_dropout_dims, - attention_dropout=self.attention_dropout, - intermediate_dropout=self.intermediate_dropout, - intermediate_dropout_dims=self.intermediate_dropout_dims, - dropout_rng_name=self.dropout_rng_name, - mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mha_kernel", self.params_init - ), - mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mlp_kernel", self.params_init - ), - mlp_activations=self.mlp_activations, - use_bias=self.use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), - apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, - output_layernorm=self.output_layernorm, - float32_attention_logits=self.float32_attention_logits, - layer_type=self.layer_type, - self_attn_mask_type=self.self_attn_mask_type, - self_attn_bias_type=self.self_attn_bias_type, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_windows=self.rotary_pos_emb_windows, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - low_rank_adaptation_scope=self.low_rank_adaptation_scope, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - enable_relative_embedding=self.enable_relative_embedding, - relative_embedding=relative_embedding_flax_module, - drop_path=self.drop_path, - fuse_qkv_params=self.fuse_qkv_params, - transpose_batch_sequence=self.transpose_batch_sequence, - enable_sequence_parallel=self.enable_sequence_parallel, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - window_size=self.window_size, - ) - - self.create_layer("transformerlayer", transformerlayer_cls) - - def __call__( - self, - inputs: JTensor, - encoded: JTensor = None, - attention_mask: JTensor = None, - encoder_decoder_mask: JTensor = None, - deterministic: bool = False, - decode: bool = False, - max_decode_length: bool = None, - ) -> JTensor: - """__call__""" - return self.transformerlayer( - inputs, - encoded, - attention_mask, - encoder_decoder_mask, - deterministic, - decode, - max_decode_length, - ) diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index a9fc6b6b6f..ef3c05a882 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -101,7 +101,7 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["jax", "flax>=0.7.1"], - tests_require=["numpy", "praxis"], + tests_require=["numpy"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) From 0fbb286c66beef2be0b73377c0a3ee34098e216b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 17 Apr 2025 18:58:22 -0700 Subject: [PATCH 48/53] replace split_bw with delay_wgrad_compute Signed-off-by: Hongbin Liu --- tests/pytorch/distributed/run_numerics.py | 12 +-- tests/pytorch/test_numerics.py | 76 +++++++++---------- transformer_engine/pytorch/module/_common.py | 23 +++--- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/grouped_linear.py | 14 ++-- .../pytorch/module/layernorm_linear.py | 8 +- .../pytorch/module/layernorm_mlp.py | 14 ++-- transformer_engine/pytorch/module/linear.py | 8 +- 8 files changed, 83 insertions(+), 74 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 48e831acee..621d036212 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -479,7 +479,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): _loss_backward(output_single_node, output_distributed) # Compute delayed weight gradient - if "split_bw" in kwargs: + if "delay_wgrad_compute" in kwargs: _loss_backward_dw(model_single_node, model_distributed) # Validate outputs and gradients @@ -503,7 +503,7 @@ def test_linear(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"params_dtype": torch.float16}, - {"split_bw": True}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column", "row"]: @@ -656,7 +656,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs _loss_backward(output_single_node, output_distributed) # Compute delayed weight gradient - if "split_bw" in kwargs: + if "delay_wgrad_compute" in kwargs: _loss_backward_dw(model_single_node, model_distributed) # Validate outputs and gradients @@ -681,7 +681,7 @@ def test_layernorm_linear(): {"params_dtype": torch.float16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, - {"split_bw": True}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: for parallel_mode in ["column"]: @@ -761,7 +761,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg # Compute loss and backpropagate _loss_backward(output_single_node, output_distributed) - if "split_bw" in kwargs: + if "delay_wgrad_compute" in kwargs: _loss_backward_dw(model_single_node, model_distributed) # Validate outputs and gradients @@ -789,7 +789,7 @@ def test_layernorm_mlp(): {"fuse_wgrad_accumulation": True}, {"return_bias": True}, {"return_layernorm_output": True}, - {"split_bw": True}, + {"delay_wgrad_compute": True}, ] for kwargs in kwargs_list: diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index cb8cced93b..b2f56efb80 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1036,7 +1036,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): +def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False): reset_rng_states() inp_hidden_states = torch.randn( @@ -1052,7 +1052,7 @@ def _test_granular_accuracy(block, bs, dtype, config, split_bw=False): out = out[0] loss = out.sum() loss.backward() - if split_bw: + if delay_wgrad_compute: block.backward_dw() torch.cuda.synchronize() @@ -1202,7 +1202,7 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulation): +def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation): config = model_configs[model] te_linear_ref = Linear( @@ -1211,7 +1211,7 @@ def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulatio bias=bias, params_dtype=dtype, device="cuda", - split_bw=False, + delay_wgrad_compute=False, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() @@ -1221,7 +1221,7 @@ def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulatio bias=bias, params_dtype=dtype, device="cuda", - split_bw=True, + delay_wgrad_compute=True, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() @@ -1235,8 +1235,8 @@ def test_linear_accuracy_split_bw(dtype, bs, model, bias, fuse_wgrad_accumulatio weight.main_grad = torch.rand_like(weight, dtype=torch.float32) te_linear_ref.weight.main_grad = weight.main_grad.clone() - te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, split_bw=True) - te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, split_bw=False) + te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, delay_wgrad_compute=False) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1435,7 +1435,7 @@ def test_layernorm_linear_accuracy( @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -def test_layernorm_linear_accuracy_split_bw( +def test_layernorm_linear_accuracy_delay_wgrad_compute( dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation ): config = model_configs[model] @@ -1449,7 +1449,7 @@ def test_layernorm_linear_accuracy_split_bw( params_dtype=dtype, zero_centered_gamma=zero_centered_gamma, device="cuda", - split_bw=False, + delay_wgrad_compute=False, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() @@ -1462,7 +1462,7 @@ def test_layernorm_linear_accuracy_split_bw( params_dtype=dtype, zero_centered_gamma=zero_centered_gamma, device="cuda", - split_bw=True, + delay_wgrad_compute=True, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() @@ -1479,8 +1479,8 @@ def test_layernorm_linear_accuracy_split_bw( weight.main_grad = torch.rand_like(weight, dtype=torch.float32) ln_linear_ref.weight.main_grad = weight.main_grad.clone() - te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, split_bw=True) - te_outputs_ref = _test_granular_accuracy(ln_linear_ref, bs, dtype, config, split_bw=False) + te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy(ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1570,12 +1570,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -def test_layernorm_mlp_accuracy_split_bw( +def test_layernorm_mlp_accuracy_delay_wgrad_compute( dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation ): config = model_configs[model] - ln_mlp_split_bw = LayerNormMLP( + ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, eps=config.eps, @@ -1583,7 +1583,7 @@ def test_layernorm_mlp_accuracy_split_bw( normalization=normalization, params_dtype=dtype, device="cuda", - split_bw=True, + delay_wgrad_compute=True, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() @@ -1595,32 +1595,32 @@ def test_layernorm_mlp_accuracy_split_bw( normalization=normalization, params_dtype=dtype, device="cuda", - split_bw=False, + delay_wgrad_compute=False, fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() # Share params with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp_split_bw.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) if normalization != "RMSNorm": - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp_split_bw.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp_split_bw.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp_split_bw.fc2_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp_split_bw.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp_split_bw.fc2_bias.clone()) + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) if fuse_wgrad_accumulation: - ln_mlp_split_bw.fc1_weight.main_grad = torch.rand_like( - ln_mlp_split_bw.fc1_weight, dtype=torch.float32 + ln_mlp.fc1_weight.main_grad = torch.rand_like( + ln_mlp.fc1_weight, dtype=torch.float32 ) - ln_mlp_ref.fc1_weight.main_grad = ln_mlp_split_bw.fc1_weight.main_grad.clone() - ln_mlp_split_bw.fc2_weight.main_grad = torch.rand_like( - ln_mlp_split_bw.fc2_weight, dtype=torch.float32 + ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() + ln_mlp.fc2_weight.main_grad = torch.rand_like( + ln_mlp.fc2_weight, dtype=torch.float32 ) - ln_mlp_ref.fc2_weight.main_grad = ln_mlp_split_bw.fc2_weight.main_grad.clone() + ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() - te_outputs = _test_granular_accuracy(ln_mlp_split_bw, bs, dtype, config, split_bw=True) - te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, split_bw=False) + te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) + te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1628,7 +1628,7 @@ def test_layernorm_mlp_accuracy_split_bw( def _test_grouped_linear_accuracy( - block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw=False + block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute=False ): reset_rng_states() if fp8: @@ -1670,7 +1670,7 @@ def _test_grouped_linear_accuracy( ) loss = out.sum() loss.backward() - if split_bw: + if delay_wgrad_compute: if isinstance(block, GroupedLinear): block.backward_dw() else: @@ -1697,7 +1697,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("split_bw", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy( dtype, num_gemms, @@ -1707,7 +1707,7 @@ def test_grouped_linear_accuracy( fp8_model_params, fuse_wgrad_accumulation, bias, - split_bw, + delay_wgrad_compute, parallel_mode=None, ): fp8 = recipe is not None @@ -1732,7 +1732,7 @@ def test_grouped_linear_accuracy( parallel_mode=parallel_mode, device="cuda", fuse_wgrad_accumulation=fuse_wgrad_accumulation, - split_bw=split_bw, + delay_wgrad_compute=delay_wgrad_compute, ).eval() sequential_linear = torch.nn.ModuleList( [ @@ -1769,10 +1769,10 @@ def test_grouped_linear_accuracy( recipe, fp8, fuse_wgrad_accumulation, - split_bw, + delay_wgrad_compute, ) outputs = _test_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, split_bw + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute ) # Shoule be bit-wise match @@ -1792,7 +1792,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe): fp8_model_params=True, fuse_wgrad_accumulation=True, bias=True, - split_bw=False, + delay_wgrad_compute=False, ) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 85e44b8de8..1efe2dce68 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -225,22 +225,23 @@ class WeightGradStore: This class enables split backward propagation for better memory efficiency. """ - def __init__(self, split_bw=False, ub_bulk_wgrad=False): + def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False): """ Initialize the WeightGradStore. Args: - split_bw (bool): Whether to enable split backward propagation + delay_wgrad_compute (bool): Whether to delay weight gradient computation + ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation """ - if split_bw: + if delay_wgrad_compute: self.context = queue.Queue() - assert ub_bulk_wgrad is False, "ub_bulk_wgrad is not supported when enabling split_bw" - self.enabled = split_bw + assert ub_bulk_wgrad is False, "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute" + self.enabled = delay_wgrad_compute else: self.context = None self.enabled = False - def split_bw(self): + def delay_wgrad_compute(self): """ Get the current split backward propagation status. @@ -249,11 +250,11 @@ def split_bw(self): """ return self.enabled - def enable_split_bw(self): + def enable_delay_wgrad_compute(self): """Enable split backward propagation.""" self.enabled = True - def disable_split_bw(self): + def disable_delay_wgrad_compute(self): """Disable split backward propagation.""" self.enabled = False @@ -265,7 +266,7 @@ def put(self, tensor_list, func): tensor_list (list): List of tensors needed for computation func (callable): Function to be executed with the tensors """ - assert self.enabled is True, "split_bw is not enabled" + assert self.enabled is True, "delay_wgrad_compute is not enabled" self.context.put([tensor_list, func]) def pop(self): @@ -273,7 +274,7 @@ def pop(self): Execute the stored computation with the stored tensors. Raises an exception if the queue is empty. """ - assert self.enabled is True, "split_bw is not enabled" + assert self.enabled is True, "delay_wgrad_compute is not enabled" if self.context.qsize() > 0: tensor_list, func = self.context.get() return func(*tensor_list), tensor_list @@ -287,6 +288,6 @@ def assert_empty(self): Assert that the queue is empty. Used for debugging and ensuring proper cleanup. """ - assert self.enabled is True, "split_bw is not enabled" + assert self.enabled is True, "delay_wgrad_compute is not enabled" rank = torch.distributed.get_rank() assert self.context.empty(), f"Queue is not empty. rank {rank}" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e1a59f38ad..17848a36bf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1145,7 +1145,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): return with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b992a72e71..54899c7b36 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -345,7 +345,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, ) # WGRAD - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) @@ -392,11 +392,11 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): wgrad_list = [None] * ctx.num_gemms if not ctx.use_bias or ( - ctx.wgrad_store is not None and ctx.wgrad_store.split_bw() and not ctx.fp8 + ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms @@ -469,6 +469,8 @@ class GroupedLinear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether to delay weight gradient computation Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and `parallel_mode` are used to determine the shapes of weights and biases. @@ -495,7 +497,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, - split_bw: bool = False, + delay_wgrad_compute: bool = False, ) -> None: super().__init__() @@ -516,7 +518,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.wgrad_store = WeightGradStore(split_bw) + self.wgrad_store = WeightGradStore(delay_wgrad_compute) self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} self._num_fp8_tensors_per_gemm = { @@ -749,7 +751,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a343ddbd82..d18e68d936 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -774,7 +774,7 @@ def backward( bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad) else: wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output) @@ -1004,6 +1004,8 @@ class LayerNormLinear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether to delay weight gradient computation symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. @@ -1038,7 +1040,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, - split_bw: bool = False, + delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: str = None, ) -> None: @@ -1058,7 +1060,7 @@ def __init__( self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name if TEDebugState.debug_enabled: self._turn_off_unsupported_features_in_debug() # turn off userbuffers diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ddacf580c4..cd861f90cb 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -840,7 +840,7 @@ def backward( use_split_accumulator=_2X_ACC_WGRAD, out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) fc2_wgrad = None # if fc2_bias is not None and fc2_bias_grad is None: @@ -862,7 +862,7 @@ def backward( fc2_bias_grad = fc2_bias_grad_ del fc2_bias_grad_ - if ctx.wgrad_store is not None and not ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute(): clear_tensor_data(act_out) # bias computation @@ -1056,7 +1056,7 @@ def backward( extra_output=fc1_dgrad_rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) fc1_wgrad = None # (fc1_wgrad_outputs), _ = ctx.wgrad_store.pop() @@ -1327,6 +1327,8 @@ class LayerNormMLP(TransformerEngineBaseModule): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propogation and activation recompute phase. + delay_wgrad_compute : bool, default = `False` + Whether to delay weight gradient computation symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. @@ -1364,7 +1366,7 @@ def __init__( ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, - split_bw: bool = False, + delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, ) -> None: super().__init__() @@ -1397,7 +1399,7 @@ def __init__( if TEDebugState.debug_enabled: self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if tp_group is None: self.tp_size = tp_size @@ -1876,7 +1878,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.split_bw(): + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): return with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c7407e1362..3596f6aebd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -697,7 +697,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], bulk_overlap=ctx.ub_bulk_wgrad, ) - if ctx.wgrad_store is not None and ctx.wgrad_store.split_bw(): + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad) else: wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output) @@ -869,6 +869,8 @@ class Linear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether to delay weight gradient computation symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. @@ -899,7 +901,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, - split_bw: bool = False, + delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, ) -> None: @@ -920,7 +922,7 @@ def __init__( if TEDebugState.debug_enabled: self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(split_bw, ub_bulk_wgrad) + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." From 559e9bd12f63474c575d4ccbf5bf2d6aa93f50c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Apr 2025 02:03:48 +0000 Subject: [PATCH 49/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 40 ++++++++++++++----- transformer_engine/pytorch/module/_common.py | 4 +- .../pytorch/module/grouped_linear.py | 4 +- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b2f56efb80..905339f4d3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1236,7 +1236,9 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) - te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, delay_wgrad_compute=False) + te_outputs_ref = _test_granular_accuracy( + te_linear_ref, bs, dtype, config, delay_wgrad_compute=False + ) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1480,7 +1482,9 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ln_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) - te_outputs_ref = _test_granular_accuracy(ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False) + te_outputs_ref = _test_granular_accuracy( + ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False + ) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1610,17 +1614,15 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) if fuse_wgrad_accumulation: - ln_mlp.fc1_weight.main_grad = torch.rand_like( - ln_mlp.fc1_weight, dtype=torch.float32 - ) + ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32) ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() - ln_mlp.fc2_weight.main_grad = torch.rand_like( - ln_mlp.fc2_weight, dtype=torch.float32 - ) + ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32) ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) - te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False) + te_outputs_ref = _test_granular_accuracy( + ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False + ) # Shoule be bit-wise match for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): @@ -1628,7 +1630,15 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( def _test_grouped_linear_accuracy( - block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute=False + block, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute=False, ): reset_rng_states() if fp8: @@ -1772,7 +1782,15 @@ def test_grouped_linear_accuracy( delay_wgrad_compute, ) outputs = _test_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, ) # Shoule be bit-wise match diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 1efe2dce68..4828e9bc10 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -235,7 +235,9 @@ def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False): """ if delay_wgrad_compute: self.context = queue.Queue() - assert ub_bulk_wgrad is False, "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute" + assert ( + ub_bulk_wgrad is False + ), "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute" self.enabled = delay_wgrad_compute else: self.context = None diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 54899c7b36..f9bb7d767a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -396,7 +396,9 @@ def handle_custom_ddp_from_mcore(weight, wgrad): wgrad_list = [None] * ctx.num_gemms if not ctx.use_bias or ( - ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8 + ctx.wgrad_store is not None + and ctx.wgrad_store.delay_wgrad_compute() + and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms From 7b2926511b50cc4b532ee78e4ef72cbfc3bbdff4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 17 Apr 2025 19:16:07 -0700 Subject: [PATCH 50/53] Update transformer_engine/pytorch/module/layernorm_mlp.py Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index cd861f90cb..9cc66bfc42 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1328,7 +1328,9 @@ class LayerNormMLP(TransformerEngineBaseModule): fused functions are warmed up before training to ensure same kernels are used for forward propogation and activation recompute phase. delay_wgrad_compute : bool, default = `False` - Whether to delay weight gradient computation + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. From bfb3d3732faf3e198a9f90c693e065b9d0735ece Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 17 Apr 2025 19:16:18 -0700 Subject: [PATCH 51/53] Update transformer_engine/pytorch/module/linear.py Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3596f6aebd..7803f4a084 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -870,7 +870,9 @@ class Linear(TransformerEngineBaseModule): the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. delay_wgrad_compute : bool, default = `False` - Whether to delay weight gradient computation + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. From c630beb6ceebb5ebcb680ce357d7ec3ae2bd0227 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 17 Apr 2025 19:16:27 -0700 Subject: [PATCH 52/53] Update transformer_engine/pytorch/module/layernorm_linear.py Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d18e68d936..d3bfed5885 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1005,7 +1005,9 @@ class LayerNormLinear(TransformerEngineBaseModule): the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. delay_wgrad_compute : bool, default = `False` - Whether to delay weight gradient computation + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. From 2a7087e732a1f6c6de1c31eccba0370e890a32e4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 17 Apr 2025 19:21:37 -0700 Subject: [PATCH 53/53] remove comments Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/layernorm_mlp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9cc66bfc42..b5f574f766 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -843,8 +843,6 @@ def backward( if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) fc2_wgrad = None - # if fc2_bias is not None and fc2_bias_grad is None: - # fc2_bias_grad = None else: fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( act_out, @@ -1048,7 +1046,7 @@ def backward( layout="NT", quantization_params=ctx.fc1_grad_weight_quantizer, grad=fuse_gemm_and_bias_fc1_wgrad, - bias=(fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None), + bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub=ub_obj_fc1_wgrad, @@ -1059,7 +1057,6 @@ def backward( if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) fc1_wgrad = None - # (fc1_wgrad_outputs), _ = ctx.wgrad_store.pop() if fuse_gemm_and_bias_fc1_wgrad: fc1_bias_grad = None else: