diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index b3adfb7dbf..7b15ebf527 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import Recipe

 from .base import (
     get_multi_stream_cublas_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
@@ -80,6 +81,7 @@ def forward(
         module,
         skip_fp8_weight_update,
         save_original_input,
+        fine_grained_activation_offloading,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -209,6 +211,27 @@ def forward(
                 if isinstance(weight, QuantizedTensorStorage):
                     weight.update_usage(columnwise_usage=True)

+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+        if fine_grained_activation_offloading and cpu_offloading:
+            raise ValueError(
+                "Do not use both fine_grained_activation_offloading and cpu_offloading."
+            )
+
+        # Record the attributes grad_added_to_main_grad of weights for backward pass
+        # since these attributes will be lost during offloading
+        if (
+            fine_grained_activation_offloading
+            and weights[0].requires_grad
+            and fuse_wgrad_accumulation
+        ):
+            grad_added_to_main_grad_list = []
+            for weight in weights:
+                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                    weight.grad_added_to_main_grad = True
+            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             *inputmats,
             *weights_fp8,
@@ -271,11 +294,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+        # Restore the attributes main_grad and grad_added_to_main_grad of weights
+        if (
+            ctx.cpu_offloading or ctx.fine_grained_activation_offloading
+        ) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
-                w.main_grad = main_grads[i]
-                weights[i] = w
+                origin_weights[i].main_grad = main_grads[i]
+                origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]

         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -426,18 +451,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
@@ -481,6 +503,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
             None,
             None,
             None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
@@ -562,6 +585,7 @@ def __init__(
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
         delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
     ) -> None:
@@ -585,6 +609,8 @@ def __init__(
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name

+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
+
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)

         self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
@@ -803,6 +829,7 @@ def forward(
             self,
             skip_fp8_weight_update,
             self.save_original_input,
+            self.fine_grained_activation_offloading,
             *weight_tensors,
             *bias_tensors,
         )
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index e1c0eab2dc..57f4e25eba 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -124,6 +124,7 @@ def forward(
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
         skip_fp8_weight_update: bool,
@@ -434,10 +435,32 @@ def forward(
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+        if fine_grained_activation_offloading and cpu_offloading:
+            raise ValueError(
+                "Do not use both fine_grained_activation_offloading and cpu_offloading."
+            )
+
+        # Record the attributes grad_added_to_main_grad of weights for backward pass
+        # since these attributes will be lost during offloading
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
+            else:
+                ctx.has_grad_added_to_main_grad = False
+
         if cpu_offloading:
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

-            if ctx.grad_added_to_main_grad:
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
@@ -570,9 +593,12 @@ def backward(

         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading:
-            if ctx.grad_added_to_main_grad:
+        # Restore the attributes grad_added_to_main_grad of weights
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
+            if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
+                if ctx.fine_grained_activation_offloading:
+                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                     origin_weight.main_grad = main_grad

@@ -1041,6 +1067,7 @@ def wgrad_gemm(
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fsdp_group
             None,  # debug
             None,  # module
@@ -1176,6 +1203,7 @@ def __init__(
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
         name: str = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1194,7 +1222,7 @@ def __init__(
         )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
-
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

         self.name = name
@@ -1595,6 +1623,7 @@ def forward(
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
             self.ub_name,
+            self.fine_grained_activation_offloading,
             self.fsdp_group,
             self,
             skip_fp8_weight_update,
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 02872439a3..7b974e4a7f 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -111,6 +111,7 @@ def forward(
         ub_bulk_dgrad: bool,
         ub_bulk_wgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fp8_output: bool,  # pylint: disable=unused-argument
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
@@ -408,10 +409,32 @@ def forward(
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+        if fine_grained_activation_offloading and cpu_offloading:
+            raise ValueError(
+                "Do not use both fine_grained_activation_offloading and cpu_offloading."
+            )
+
+        # Record the attributes grad_added_to_main_grad of weights for backward pass
+        # since these attributes will be lost during offloading
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
+            else:
+                ctx.has_grad_added_to_main_grad = False
+
         if cpu_offloading:
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

-            if ctx.grad_added_to_main_grad:
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
@@ -507,9 +530,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             else None
         )

-        if ctx.cpu_offloading:
-            if ctx.grad_added_to_main_grad:
+        # Restore the attributes main_grad and grad_added_to_main_grad of weights
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
+            if ctx.has_grad_added_to_main_grad:
                 weight = ctx.weight_object
+                if ctx.fine_grained_activation_offloading:
+                    weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                     weight.main_grad = main_grad

@@ -992,6 +1018,7 @@ def wgrad_gemm(
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fp8_output
             None,  # fsdp_group
             None,  # module
@@ -1114,6 +1141,7 @@ def __init__(
         symmetric_ar_type: Optional[str] = None,
         save_original_input: bool = False,
         name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1129,7 +1157,7 @@ def __init__(
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
         self.name = name
-
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

         if device == "meta":
@@ -1473,6 +1501,7 @@ def forward(
             self.ub_bulk_dgrad,
             self.ub_bulk_wgrad,
             self.ub_name,
+            self.fine_grained_activation_offloading,
             fp8_output,
             self.fsdp_group,
             self,
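
For reference, a minimal usage sketch of the new flag (not part of the patch). Based on the constructor changes above, Linear, LayerNormLinear, and GroupedLinear each gain a fine_grained_activation_offloading keyword that defaults to False and is rejected when combined with CPU offloading. The te.Linear module name, tensor sizes, and training-loop snippet below are illustrative assumptions; the sketch assumes a GPU build of Transformer Engine is installed.

    # Minimal sketch, assuming a CUDA-enabled Transformer Engine install;
    # sizes and the surrounding loop are illustrative only.
    import torch
    import transformer_engine.pytorch as te

    # The new constructor flag added by this diff; per the forward() check
    # above it must not be combined with CPU offloading.
    linear = te.Linear(
        1024,
        1024,
        bias=True,
        fine_grained_activation_offloading=True,
    )

    inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
    out = linear(inp)
    out.sum().backward()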
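
The record/restore logic added to forward() and backward() exists because Python-level attributes hung on a weight (such as grad_added_to_main_grad) do not survive once the saved tensor is offloaded and repacked between forward and backward, so the value is stashed on ctx in forward and written back onto the original parameter object in backward. The standalone sketch below illustrates that pattern with plain PyTorch; _OffloadedMatmul and the "note" attribute are hypothetical stand-ins, not Transformer Engine APIs.

    import torch

    class _OffloadedMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight):
            # Saved tensors may be repacked (e.g. offloaded to CPU) between
            # forward and backward, so a Python attribute such as weight.note
            # would be lost on the repacked copy.  Record the attribute and
            # keep a reference to the original Parameter on ctx instead.
            ctx.saved_note = getattr(weight, "note", None)
            ctx.weight_obj = weight
            ctx.save_for_backward(inp, weight)
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            inp, weight = ctx.saved_tensors
            # Restore the recorded attribute on the original object, mirroring
            # how the diff restores grad_added_to_main_grad in backward().
            if ctx.saved_note is not None:
                ctx.weight_obj.note = ctx.saved_note
            return grad_out @ weight, grad_out.t() @ inp

    w = torch.nn.Parameter(torch.randn(4, 8))
    w.note = "recorded in forward, restored in backward"
    x = torch.randn(2, 8, requires_grad=True)
    _OffloadedMatmul.apply(x, w).sum().backward()
    assert w.note == "recorded in forward, restored in backward"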