transformer_engine/common/recipe.py (9 changes: 5 additions & 4 deletions)

@@ -133,7 +133,8 @@ def scaling_factor_compute(amax: Tensor,

     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        assert self.override_linear_precision in (
-            (False, False, False),
-            (False, False, True),
-        ), "Only wgrad GEMM override is currently supported."
+        assert not self.override_linear_precision[0], "Fprop GEMM override is not supported."
+        # assert self.override_linear_precision in (
+        #     (False, False, False),
+        #     (False, False, True),
+        # ), "Only wgrad GEMM override is currently supported."

transformer_engine/pytorch/module.py (154 changes: 99 additions & 55 deletions)

@@ -906,20 +906,31 @@ def backward(
ctx.fp8_meta["recipe"], fprop_tensor=False
)

# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# Is weight already in bf16?
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(

@@ -1598,20 +1609,30 @@ def backward(
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )

-            # DGRAD
-            dgrad = fp8_gemm(
-                weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                # DGRAD
+                dgrad = fp8_gemm(
+                    weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                dgrad, _, _ = gemm(
+                    weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
         else:
             # DGRAD
             dgrad, _, _ = gemm(

@@ -2321,21 +2342,33 @@ def backward(
             fp8_dtype_backward = get_fp8_te_dtype(
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                # FC2 DGRAD
+                fc2_dgrad = fp8_gemm(
+                    fc2_weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM2_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                fc2_dgrad, _, _ = gemm(
+                    fc2_weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    gelu=not ctx.bias_gelu_nvfusion,
+                    grad=True,
+                    gelu_input=fc1_out,
+                )

-            # FC2 DGRAD; Unconditional
-            fc2_dgrad = fp8_gemm(
-                fc2_weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM2_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )

             # FC2 WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:

@@ -2402,19 +2435,30 @@ def backward(
                 dgelu_t = None

             # FC1 DGRAD: Unconditional
-            fc1_dgrad = fp8_gemm(
-                fc1_weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                dgelu,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT2,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                fc1_dgrad = fp8_gemm(
+                    fc1_weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    dgelu,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT2,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                # NOTE: This will not work if wgrad is in FP8
+                fc1_dgrad, _, _ = gemm(
+                    fc1_weight,
+                    dgelu_no_fp8,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
         else:
             # FC2 DGRAD; Unconditional
             fc2_dgrad, _, _ = gemm(
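
For completeness, a rough smoke check one might run against the new dgrad override, assuming an FP8-capable GPU and the same transformer_engine API as above; it exercises only the Linear path (the LayerNormLinear and LayerNormMLP hunks follow the same pattern), and the comparison is illustrative rather than a strict tolerance test:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

def input_grad(recipe):
    # Identical seeding in both runs so only the dgrad path differs.
    torch.manual_seed(0)
    layer = te.Linear(256, 256).cuda()
    x = torch.randn(32, 256, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)
    y.sum().backward()
    return x.grad.clone()

g_fp8 = input_grad(DelayedScaling(fp8_format=Format.HYBRID))
g_override = input_grad(
    DelayedScaling(fp8_format=Format.HYBRID,
                   override_linear_precision=(False, True, False))
)

# The two input gradients should agree to within FP8 rounding error.
print((g_fp8 - g_override).abs().max())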