From ae109ffa8a899bdaa63cd701ca03b5184a746069 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Tue, 2 Sep 2025 23:48:02 +0800 Subject: [PATCH 01/10] adapt grouped_linear, layernorm_linear and linear --- .../pytorch/module/grouped_linear.py | 28 ++++++++++++++++--- .../pytorch/module/layernorm_linear.py | 28 ++++++++++++++++--- transformer_engine/pytorch/module/linear.py | 28 ++++++++++++++++--- 3 files changed, 72 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efaca..18aef78c13 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -209,6 +209,24 @@ def forward( if isinstance(weight, QuantizedTensorBase): weight.update_usage(columnwise_usage=True) + offload_activation = False + if hasattr(inp, 'offloading_activation'): + offload_activation = True + for i in range(num_gemms): + inputmats[i].offloading_activation = inp.offloading_activation + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: + grad_added_to_main_grad_list = [] + for weight in weights: + if weight.requires_grad and hasattr(weight, 'grad_added_to_main_grad'): + grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad) + weight.grad_added_to_main_grad = True + ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list + tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, @@ -271,11 +289,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): - w = torch.nn.Parameter(weights[i], weights[i].requires_grad) - w.main_grad = main_grads[i] - weights[i] = w + if not ctx.cpu_offloading: + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) + weights[i] = w + weights[i].main_grad = main_grads[i] + weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f31132..9af38d8e76 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -424,10 +424,28 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + offload_activation = False + if hasattr(inp, 'offloading_activation'): + offload_activation = True + if inputmat.is_contiguous(): + inputmat = inputmat.contiguous() + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if hasattr(weight, 'grad_added_to_main_grad'): + ctx.has_grad_added_to_main_grad = True + ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad + weight.grad_added_to_main_grad = True + else: + ctx.has_grad_added_to_main_grad = False + if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: + if ctx.has_grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. # You need to preserve the weight object to have all the attributes user @@ -560,9 +578,11 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: + if ctx.cpu_offloading or ctx.offload_activation: + if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object + if ctx.offload_activation: + origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1d..c725c92c11 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -395,10 +395,28 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + offload_activation = False + if hasattr(inp, 'offload_activation'): + offload_activation = True + if saved_inputmat.is_contiguous(): + saved_inputmat = saved_inputmat.contiguous() + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if hasattr(weight, 'grad_added_to_main_grad'): + ctx.has_grad_added_to_main_grad = True + ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad + weight.grad_added_to_main_grad = True + else: + ctx.has_grad_added_to_main_grad = False + if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: + if ctx.has_grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. # You need to preserve the weight object to have all the attributes user @@ -493,9 +511,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: + if ctx.cpu_offloading or ctx.offload_activation: + if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object + if ctx.offload_activation: + weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: weight.main_grad = main_grad From c44b45d37c7ac15a7ac999c908194587a5488e51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:11:35 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 8 +++++--- transformer_engine/pytorch/module/layernorm_linear.py | 8 +++++--- transformer_engine/pytorch/module/linear.py | 10 ++++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 18aef78c13..c9402cded5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -210,19 +210,21 @@ def forward( weight.update_usage(columnwise_usage=True) offload_activation = False - if hasattr(inp, 'offloading_activation'): + if hasattr(inp, "offloading_activation"): offload_activation = True for i in range(num_gemms): inputmats[i].offloading_activation = inp.offloading_activation ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: - if weight.requires_grad and hasattr(weight, 'grad_added_to_main_grad'): + if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad) weight.grad_added_to_main_grad = True ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9af38d8e76..f8aa9a4718 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -425,17 +425,19 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") offload_activation = False - if hasattr(inp, 'offloading_activation'): + if hasattr(inp, "offloading_activation"): offload_activation = True if inputmat.is_contiguous(): inputmat = inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: - if hasattr(weight, 'grad_added_to_main_grad'): + if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c725c92c11..751bc6832b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -396,17 +396,19 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") offload_activation = False - if hasattr(inp, 'offload_activation'): + if hasattr(inp, "offload_activation"): offload_activation = True if saved_inputmat.is_contiguous(): saved_inputmat = saved_inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") - + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: - if hasattr(weight, 'grad_added_to_main_grad'): + if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True From 93be702489eb5c588ed6704b8c6fa7f8fe41cd5b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 07:00:32 -0700 Subject: [PATCH 03/10] Bug fix Signed-off-by: Hongbin Liu --- .../pytorch/module/grouped_linear.py | 33 ++++++++++--------- .../pytorch/module/layernorm_linear.py | 18 ++++++---- transformer_engine/pytorch/module/linear.py | 17 ++++++---- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9402cded5..28c92156ee 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -14,6 +14,7 @@ from transformer_engine.common.recipe import Recipe from .base import ( get_multi_stream_cublas_workspace, + get_dummy_wgrad, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -80,6 +81,7 @@ def forward( module, skip_fp8_weight_update, save_original_input, + offload_activation, *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -209,11 +211,10 @@ def forward( if isinstance(weight, QuantizedTensorBase): weight.update_usage(columnwise_usage=True) - offload_activation = False - if hasattr(inp, "offloading_activation"): - offload_activation = True - for i in range(num_gemms): - inputmats[i].offloading_activation = inp.offloading_activation + for i in range(num_gemms): + weights[i].offloading_activation = False + weights_fp8[i].offloading_activation = False + biases[i].offloading_activation = False ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -448,18 +449,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -506,6 +504,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, *wgrad_list, *grad_biases, ) @@ -587,6 +586,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, + offload_activation: bool = False, delay_wgrad_compute: bool = False, save_original_input: bool = False, ) -> None: @@ -610,6 +610,8 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.offload_activation = offload_activation + self.wgrad_store = WeightGradStore(delay_wgrad_compute) self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} @@ -825,6 +827,7 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), + self.offload_activation, self, skip_fp8_weight_update, self.save_original_input, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f8aa9a4718..9eaec062b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -122,6 +122,7 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_name: str, + offload_activation: bool, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, @@ -424,11 +425,12 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - offload_activation = False - if hasattr(inp, "offloading_activation"): - offload_activation = True - if inputmat.is_contiguous(): - inputmat = inputmat.contiguous() + # Do not offload weights and biases + weight.offloading_activation = False + weightmat.offloading_activation = False + if bias is not None: + bias.offloading_activation = False + ln_weight.offloading_activation = False ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -441,6 +443,7 @@ def forward( ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True + ctx.weight_object = weight else: ctx.has_grad_added_to_main_grad = False @@ -1043,6 +1046,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name + None, # offload_activation None, # fsdp_group None, # debug None, # module @@ -1178,6 +1182,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: str = None, + offload_activation: bool = False, ) -> None: super().__init__() @@ -1194,7 +1199,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name @@ -1597,6 +1602,7 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_name, + self.offload_activation, self.fsdp_group, self, skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 751bc6832b..6d7864dda9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -109,6 +109,7 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, + offload_activation: bool, fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -395,11 +396,6 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - offload_activation = False - if hasattr(inp, "offload_activation"): - offload_activation = True - if saved_inputmat.is_contiguous(): - saved_inputmat = saved_inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -412,6 +408,7 @@ def forward( ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True + ctx.weight_object = weight else: ctx.has_grad_added_to_main_grad = False @@ -426,6 +423,11 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight + # Do not offload weights and biases + weight.offloading_activation = False + weightmat.offloading_activation = False + if bias is not None: + bias.offloading_activation = False # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -990,6 +992,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name + None, # offload_activation None, # fp8_output None, # fsdp_group None, # module @@ -1112,6 +1115,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + offload_activation: bool = False, ) -> None: super().__init__() @@ -1127,7 +1131,7 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input self.name = name - + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": @@ -1474,6 +1478,7 @@ def forward( self.ub_bulk_dgrad, self.ub_bulk_wgrad, self.ub_name, + self.offload_activation, fp8_output, self.fsdp_group, self, From f0726f704a403dd390e86652485a279c4bfc2360 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 22:41:00 -0700 Subject: [PATCH 04/10] renaming Signed-off-by: Hongbin Liu --- .../pytorch/module/grouped_linear.py | 18 +++++++-------- .../pytorch/module/layernorm_linear.py | 22 +++++++++---------- transformer_engine/pytorch/module/linear.py | 22 +++++++++---------- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 28c92156ee..1361c4c217 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -81,7 +81,7 @@ def forward( module, skip_fp8_weight_update, save_original_input, - offload_activation, + fine_grained_activation_offloading, *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -215,14 +215,14 @@ def forward( weights[i].offloading_activation = False weights_fp8[i].offloading_activation = False biases[i].offloading_activation = False - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): @@ -292,7 +292,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: + if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): if not ctx.cpu_offloading: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) @@ -586,7 +586,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, delay_wgrad_compute: bool = False, save_original_input: bool = False, ) -> None: @@ -610,7 +610,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute) @@ -827,7 +827,7 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), - self.offload_activation, + self.fine_grained_activation_offloading, self, skip_fp8_weight_update, self.save_original_input, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9eaec062b2..55e226066a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -122,7 +122,7 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_name: str, - offload_activation: bool, + fine_grained_activation_offloading: bool, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, @@ -431,14 +431,14 @@ def forward( if bias is not None: bias.offloading_activation = False ln_weight.offloading_activation = False - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad @@ -583,10 +583,10 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. - if ctx.cpu_offloading or ctx.offload_activation: + if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object - if ctx.offload_activation: + if ctx.fine_grained_activation_offloading: origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad @@ -1046,7 +1046,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name - None, # offload_activation + None, # fine_grained_activation_offloading None, # fsdp_group None, # debug None, # module @@ -1182,7 +1182,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: str = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, ) -> None: super().__init__() @@ -1199,7 +1199,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name @@ -1602,7 +1602,7 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_name, - self.offload_activation, + self.fine_grained_activation_offloading, self.fsdp_group, self, skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6d7864dda9..aa449a10a4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -109,7 +109,7 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, - offload_activation: bool, + fine_grained_activation_offloading: bool, fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -396,14 +396,14 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad @@ -515,10 +515,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) - if ctx.cpu_offloading or ctx.offload_activation: + if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object - if ctx.offload_activation: + if ctx.fine_grained_activation_offloading: weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: weight.main_grad = main_grad @@ -992,7 +992,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name - None, # offload_activation + None, # fine_grained_activation_offloading None, # fp8_output None, # fsdp_group None, # module @@ -1115,7 +1115,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, ) -> None: super().__init__() @@ -1131,7 +1131,7 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input self.name = name - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": @@ -1478,7 +1478,7 @@ def forward( self.ub_bulk_dgrad, self.ub_bulk_wgrad, self.ub_name, - self.offload_activation, + self.fine_grained_activation_offloading, fp8_output, self.fsdp_group, self, From a1c6e073fbffd74516a8b6a5611a926b86fa0622 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 04:34:56 -0700 Subject: [PATCH 05/10] minor fix for fp8 Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1361c4c217..311475818a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -827,10 +827,10 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), - self.fine_grained_activation_offloading, self, skip_fp8_weight_update, self.save_original_input, + self.fine_grained_activation_offloading, *weight_tensors, *bias_tensors, ) From 7933781da84f26bd0029467ee6c1e68aa18ce47e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 26 Sep 2025 01:03:45 -0700 Subject: [PATCH 06/10] temp fix to enable --overlap-grad-reduce Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9966c78f8..94d83fd638 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -459,7 +459,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad): list(weight.main_grad.shape), weight.dtype, ) - elif ctx.fuse_wgrad_accumulation: + # TODO: Need to check why weight doesn't have attr grad_added_to_main_grad when fine_grained_activation_offloading is True. + elif ctx.fuse_wgrad_accumulation and not ctx.fine_grained_activation_offloading: wgrad = None else: wgrad = None From 963b39c50c4025eb447f524335ad239f57bc6935 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 28 Sep 2025 23:45:22 -0700 Subject: [PATCH 07/10] fix to enable --overlap-grad-reduce Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 94d83fd638..e4e622d157 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -294,11 +294,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): - if not ctx.cpu_offloading: - w = torch.nn.Parameter(weights[i], weights[i].requires_grad) - weights[i] = w - weights[i].main_grad = main_grads[i] - weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] + origin_weights[i].main_grad = main_grads[i] + origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) @@ -459,8 +456,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): list(weight.main_grad.shape), weight.dtype, ) - # TODO: Need to check why weight doesn't have attr grad_added_to_main_grad when fine_grained_activation_offloading is True. - elif ctx.fuse_wgrad_accumulation and not ctx.fine_grained_activation_offloading: + elif ctx.fuse_wgrad_accumulation: wgrad = None else: wgrad = None From 98d354c3d3b2571c7752f85a2fcf97fa6fd2aab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 11 Oct 2025 12:42:07 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 13 ++++++++++--- .../pytorch/module/layernorm_linear.py | 9 +++++++-- transformer_engine/pytorch/module/linear.py | 9 +++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d258504df1..1815a5ef96 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -219,10 +219,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weights[0].requires_grad + and fuse_wgrad_accumulation + ): grad_added_to_main_grad_list = [] for weight in weights: if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): @@ -292,7 +297,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: + if ( + ctx.cpu_offloading or ctx.fine_grained_activation_offloading + ) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): origin_weights[i].main_grad = main_grads[i] origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bba7f991a4..f230fc13a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -445,10 +445,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0ac223bf89..9acd5ad1c6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -413,10 +413,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad From 88c7b050d8acd7fb9c363bff66ed64e9cd04a694 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sat, 11 Oct 2025 06:00:18 -0700 Subject: [PATCH 09/10] add comments Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++++- transformer_engine/pytorch/module/layernorm_linear.py | 5 ++++- transformer_engine/pytorch/module/linear.py | 5 ++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d258504df1..e0f43986b8 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -211,6 +211,7 @@ def forward( if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) + # Do not offload weights and biases for i in range(num_gemms): weights[i].offloading_activation = False weights_fp8[i].offloading_activation = False @@ -219,9 +220,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: @@ -292,6 +295,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + # Restore the attributes main_grad and grad_added_to_main_grad of weights if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): origin_weights[i].main_grad = main_grads[i] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bba7f991a4..adf616f7a7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -445,9 +445,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True @@ -593,6 +595,7 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. + # Restore the attributes grad_added_to_main_grad of weights if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0ac223bf89..7180636ba4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -413,9 +413,11 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + "Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) + # Record the attributes grad_added_to_main_grad of weights for backward pass + # since these attributes will be lost during offloading if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True @@ -529,6 +531,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) + # Restore the attributes main_grad and grad_added_to_main_grad of weights if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object From fe9bab470cc5fd1a6f575301df425a6f6862d258 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Tue, 21 Oct 2025 01:39:39 -0700 Subject: [PATCH 10/10] remove unused code Signed-off-by: hongbinl --- transformer_engine/pytorch/module/grouped_linear.py | 5 ----- transformer_engine/pytorch/module/layernorm_linear.py | 6 ------ transformer_engine/pytorch/module/linear.py | 5 ----- 3 files changed, 16 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 6b4c35170c..7b15ebf527 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -211,11 +211,6 @@ def forward( if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) - # Do not offload weights and biases - for i in range(num_gemms): - weights[i].offloading_activation = False - weights_fp8[i].offloading_activation = False - biases[i].offloading_activation = False ctx.fine_grained_activation_offloading = fine_grained_activation_offloading if fine_grained_activation_offloading and cpu_offloading: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 75c8a030d9..57f4e25eba 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -435,12 +435,6 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - # Do not offload weights and biases - weight.offloading_activation = False - weightmat.offloading_activation = False - if bias is not None: - bias.offloading_activation = False - ln_weight.offloading_activation = False ctx.fine_grained_activation_offloading = fine_grained_activation_offloading if fine_grained_activation_offloading and cpu_offloading: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ffbe02ae98..7b974e4a7f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -442,11 +442,6 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - # Do not offload weights and biases - weight.offloading_activation = False - weightmat.offloading_activation = False - if bias is not None: - bias.offloading_activation = False # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat,