8 changes: 4 additions & 4 deletions transformer_engine/common/recipe.py
@@ -133,7 +133,7 @@ def scaling_factor_compute(amax: Tensor,

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert self.override_linear_precision in (
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
#assert self.override_linear_precision in (
# (False, False, False),
# (False, False, True),
#), "Only wgrad GEMM override is currently supported."
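With the check relaxed, a recipe can now request that the forward (fprop) GEMM run in higher precision while FP8 stays enabled elsewhere. A minimal usage sketch, assuming the existing public DelayedScaling and fp8_autocast API; the layer sizes and the (True, False, False) setting are illustrative only:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# override_linear_precision is (fprop, dgrad, wgrad); True keeps that GEMM
# in higher precision instead of FP8.
fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,
    override_linear_precision=(True, False, False),  # illustrative setting
)

layer = te.Linear(768, 768)
inp = torch.randn(16, 768, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
out.sum().backward()
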
120 changes: 112 additions & 8 deletions transformer_engine/pytorch/module.py
@@ -695,6 +695,7 @@ def forward(
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -753,7 +754,7 @@
else:
ln_out_total = ln_out

if fp8:
if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -816,6 +817,32 @@
use_bias=use_bias,
)

if fp8:
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)

ln_out = cast_to_fp8(
ln_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)

if is_grad_enabled:
ctx.save_for_backward(
inputmat,
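This added block is the core of the LayerNormLinear change: when the fprop override is active the forward GEMM runs outside FP8, but the FP8 and transposed weight copies, plus the FP8 cast of ln_out, are still produced, presumably so that the FP8 backward GEMMs and the tensors saved for backward keep their usual form. A self-contained sketch of the gating predicate used throughout this diff; the names below are stand-ins, not TE code:

from typing import NamedTuple


class OverrideLinearPrecision(NamedTuple):
    # Stand-in for recipe.override_linear_precision: (fprop, dgrad, wgrad).
    fprop: bool = False
    dgrad: bool = False
    wgrad: bool = False


def fp8_fprop_gemm(fp8_enabled: bool, override: OverrideLinearPrecision) -> bool:
    # The predicate every forward FP8 branch in this diff is gated on.
    return fp8_enabled and not override.fprop


assert fp8_fprop_gemm(True, OverrideLinearPrecision())                # FP8 fprop
assert not fp8_fprop_gemm(True, OverrideLinearPrecision(fprop=True))  # overridden
assert not fp8_fprop_gemm(False, OverrideLinearPrecision())           # FP8 off
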
@@ -1413,11 +1440,9 @@ def forward(
), "Input and weight dimensions are not compatible for FP8 execution."

update_fp8_weights = is_first_microbatch is None or is_first_microbatch

# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat

if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

@@ -1445,12 +1470,18 @@
), None

# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
if not fp8_meta["recipe"].override_linear_precision.fprop:
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
else:
inputmat_total = inputmat
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat_no_fp8, tp_group)
else:
inputmat_total = inputmat_no_fp8

if fp8:
if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
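Keeping inputmat_no_fp8 around lets the Linear path hand the unquantized input to the sequence-parallel all-gather whenever the fprop GEMM is overridden; only the pure FP8 path gathers the FP8 inputmat. Roughly, the choice condenses to the following hypothetical helper (plain function, not the TE code):

import torch


def pick_gather_source(inputmat: torch.Tensor,
                       inputmat_no_fp8: torch.Tensor,
                       fp8: bool,
                       fprop_override: bool) -> torch.Tensor:
    # Gather (or use directly) the FP8 input only when the forward GEMM
    # itself will run in FP8; otherwise pass the higher-precision copy on.
    if fp8 and not fprop_override:
        return inputmat
    return inputmat_no_fp8
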
@@ -1514,6 +1545,26 @@
use_bias=use_bias,
)

if fp8:
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)

if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
ctx.save_for_backward(
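The weight casts above run only when update_fp8_weights is set, which follows is_first_microbatch (see the -1413 hunk): the FP8 and transposed weight copies are refreshed on the first microbatch of an accumulation step and reused afterwards. A sketch of that caching through the module's public forward argument; shapes and the recipe setting are illustrative:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(override_linear_precision=(True, False, False))
layer = te.Linear(768, 768)
microbatches = [torch.randn(16, 768, device="cuda") for _ in range(4)]

for i, micro in enumerate(microbatches):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # FP8 weight copies (including the transposed ones produced by
        # fp8_cast_transpose_fused) are refreshed only when
        # is_first_microbatch is True or left as None.
        out = layer(micro, is_first_microbatch=(i == 0))
    out.sum().backward()
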
@@ -2040,6 +2091,8 @@ def forward(
# of an extra fp8 cast.
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2091,7 +2144,7 @@ def forward(
else:
ln_out_total = ln_out

if fp8:
if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -2222,6 +2275,57 @@
bias=fc2_bias,
use_bias=use_bias,
)

if fp8:
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=fc1_weight_fp8,
transpose_out=fc1_weight_t_fp8,
)

fp8_cast_transpose_fused(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8,
)
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)

ln_out = cast_to_fp8(
ln_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)

gelu_out = cast_to_fp8(
gelu_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
)

if is_grad_enabled:
ctx.save_for_backward(
inputmat,
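For LayerNormMLP the same post-GEMM casts cover both GEMMs. As a recap of the block above, these are the forward FP8 tensor slots it keeps refreshed even when the fprop GEMMs run in higher precision:

# Slot names as used in the diff above (tex.FP8FwdTensors.*).
FP8_FWD_SLOTS = {
    "GEMM1_INPUT": "ln_out",      # layernorm output feeding fc1
    "GEMM1_WEIGHT": "fc1_weight",
    "GEMM2_INPUT": "gelu_out",    # gelu output feeding fc2
    "GEMM2_WEIGHT": "fc2_weight",
}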