55 changes: 41 additions & 14 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -14,6 +14,7 @@
from transformer_engine.common.recipe import Recipe
from .base import (
get_multi_stream_cublas_workspace,
get_dummy_wgrad,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
@@ -80,6 +81,7 @@ def forward(
module,
skip_fp8_weight_update,
save_original_input,
fine_grained_activation_offloading,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
@@ -209,6 +211,27 @@ def forward(
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)

ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
"Do not use both fine_grained_activation_offloading and cpu_offloading."
)

# Record each weight's grad_added_to_main_grad attribute for the backward pass,
# since these attributes will be lost during offloading
if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list

tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
@@ -271,11 +294,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
# Restore the attributes main_grad and grad_added_to_main_grad of weights
if (
ctx.cpu_offloading or ctx.fine_grained_activation_offloading
) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w
origin_weights[i].main_grad = main_grads[i]
origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]

# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -426,18 +451,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
zero=True,
)
else:
wgrad = torch.empty(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
@@ -481,6 +503,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
None,
None,
None,
None,
*wgrad_list,
*grad_biases,
)
@@ -562,6 +585,7 @@ def __init__(
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
delay_wgrad_compute: bool = False,
save_original_input: bool = False,
) -> None:
@@ -585,6 +609,8 @@ def __init__(
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name

self.fine_grained_activation_offloading = fine_grained_activation_offloading

self.wgrad_store = WeightGradStore(delay_wgrad_compute)

self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
@@ -803,6 +829,7 @@ def forward(
self,
skip_fp8_weight_update,
self.save_original_input,
self.fine_grained_activation_offloading,
*weight_tensors,
*bias_tensors,
)
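Taken together, the grouped_linear.py changes do two things: before activations are offloaded, the forward pass records each weight's grad_added_to_main_grad flag and temporarily sets it to True, and the backward pass restores those flags together with main_grad; the dummy wgrad buffers are now obtained from get_dummy_wgrad instead of being allocated directly with torch.zeros/torch.empty. Below is a minimal plain-PyTorch sketch of the record/restore bookkeeping; record_and_mark and restore are illustrative helpers, not TE API.

import torch

def record_and_mark(weights):
    # Forward-side sketch: remember each weight's grad_added_to_main_grad flag
    # (the attribute can be lost while tensors are offloaded) and mark it True
    # so wgrad is not accumulated into main_grad a second time.
    recorded = []
    for w in weights:
        if w.requires_grad and hasattr(w, "grad_added_to_main_grad"):
            recorded.append(w.grad_added_to_main_grad)
            w.grad_added_to_main_grad = True
    return recorded

def restore(weights, recorded, main_grads):
    # Backward-side sketch: reattach main_grad and put the recorded flags back
    # on the original Parameter objects.
    for w, flag, mg in zip(weights, recorded, main_grads):
        w.main_grad = mg
        w.grad_added_to_main_grad = flag

# Illustrative round trip with dummy parameters.
ws = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(2)]
mgs = [torch.zeros_like(w) for w in ws]
for w in ws:
    w.grad_added_to_main_grad = False
flags = record_and_mark(ws)   # flags == [False, False]; attributes now True
restore(ws, flags, mgs)       # attributes back to False; main_grad reattached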
39 changes: 34 additions & 5 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -124,6 +124,7 @@ def forward(
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
@@ -434,10 +435,32 @@ def forward(
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
"Do not use both fine_grained_activation_offloading and cpu_offloading."
)

# Record the weight's grad_added_to_main_grad attribute for the backward pass,
# since this attribute will be lost during offloading
if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
if ctx.has_grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
@@ -570,9 +593,12 @@ def backward(

# For CPU offloading, weight and weight.main_grad are offloaded to different tensors,
# so we need to connect them back into one.
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
# Restore the attributes grad_added_to_main_grad of weights
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
if ctx.has_grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad

@@ -1041,6 +1067,7 @@ def wgrad_gemm(
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fsdp_group
None, # debug
None, # module
@@ -1176,6 +1203,7 @@ def __init__(
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: str = None,
fine_grained_activation_offloading: bool = False,
) -> None:
super().__init__()

@@ -1194,7 +1222,7 @@ def __init__(
)
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type

self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name

@@ -1595,6 +1623,7 @@ def forward(
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fine_grained_activation_offloading,
self.fsdp_group,
self,
skip_fp8_weight_update,
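The layernorm_linear.py changes apply the same idea to a single weight: the forward pass stashes the current grad_added_to_main_grad value and the original Parameter object on ctx (Torch hooks and offloading can hand backward a plain Tensor with the custom attributes stripped), and the backward pass restores the flag and reattaches main_grad before wgrad accumulation. A standalone sketch of this pattern in plain PyTorch follows; the class name and shapes are illustrative, not the TE implementation.

import torch

class _OffloadFlagSketch(torch.autograd.Function):
    # Sketch of the bookkeeping in this hunk: record the weight's
    # grad_added_to_main_grad flag and the Parameter object on ctx in forward,
    # restore the flag on that same object in backward.

    @staticmethod
    def forward(ctx, inp, weight):
        if hasattr(weight, "grad_added_to_main_grad"):
            ctx.has_grad_added_to_main_grad = True
            ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
            weight.grad_added_to_main_grad = True
            # Keep the Parameter itself so its attributes survive offloading.
            ctx.weight_object = weight
        else:
            ctx.has_grad_added_to_main_grad = False
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        if ctx.has_grad_added_to_main_grad:
            # Put the recorded flag back on the original weight object.
            ctx.weight_object.grad_added_to_main_grad = ctx.grad_added_to_main_grad
        return grad_out @ weight, grad_out.t() @ inp

# Forward marks the flag True; backward restores the recorded value.
w = torch.nn.Parameter(torch.randn(4, 8))
w.grad_added_to_main_grad = False
x = torch.randn(2, 8, requires_grad=True)
_OffloadFlagSketch.apply(x, w).sum().backward()
assert w.grad_added_to_main_grad is False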
39 changes: 34 additions & 5 deletions transformer_engine/pytorch/module/linear.py
@@ -111,6 +111,7 @@ def forward(
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
@@ -408,10 +409,32 @@ def forward(
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
"Do not use both fine_grained_activation_offloading and cpu_offloading."
)

# Record the weight's grad_added_to_main_grad attribute for the backward pass,
# since this attribute will be lost during offloading
if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
if ctx.has_grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
@@ -507,9 +530,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
else None
)

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
# Restore the attributes main_grad and grad_added_to_main_grad of weights
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
if ctx.has_grad_added_to_main_grad:
weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad

@@ -992,6 +1018,7 @@ def wgrad_gemm(
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fp8_output
None, # fsdp_group
None, # module
@@ -1114,6 +1141,7 @@ def __init__(
symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False,
name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
) -> None:
super().__init__()

@@ -1129,7 +1157,7 @@ def __init__(
self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
self.name = name

self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

if device == "meta":
@@ -1473,6 +1501,7 @@ def forward(
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
self.fine_grained_activation_offloading,
fp8_output,
self.fsdp_group,
self,
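linear.py receives the same forward/backward bookkeeping plus the new fine_grained_activation_offloading constructor flag, which is rejected when the global CPU-offloading path is also active. A hedged usage sketch, assuming the public module is transformer_engine.pytorch.Linear with its existing constructor arguments (only the new flag comes from this PR):

import torch
import transformer_engine.pytorch as te  # import path assumed

# Enable the new per-module offloading flag; the rest is the usual Linear API.
layer = te.Linear(1024, 1024, bias=True, fine_grained_activation_offloading=True)

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
out = layer(x)
out.sum().backward()

# Enabling the flag while the global cpu_offloading context is active raises:
#   ValueError: Do not use both fine_grained_activation_offloading and cpu_offloading.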