From cdfd8ad8f3b81e9c47261bfd6cf0606b616d79d1 Mon Sep 17 00:00:00 2001
From: ANMOL GUPTA
Date: Tue, 18 Apr 2023 22:33:42 -0700
Subject: [PATCH] override dgrad precision

---
 transformer_engine/common/recipe.py  |   9 +-
 transformer_engine/pytorch/module.py | 154 +++++++++++++++++----------
 2 files changed, 104 insertions(+), 59 deletions(-)

diff --git a/transformer_engine/common/recipe.py b/transformer_engine/common/recipe.py
index 3bb5320475..2bba4555da 100644
--- a/transformer_engine/common/recipe.py
+++ b/transformer_engine/common/recipe.py
@@ -133,7 +133,8 @@ def scaling_factor_compute(amax: Tensor,
 
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        assert self.override_linear_precision in (
-            (False, False, False),
-            (False, False, True),
-        ), "Only wgrad GEMM override is currently supported."
+        assert not self.override_linear_precision[0], "Fprop GEMM precision override is not supported."
+        #assert self.override_linear_precision in (
+        #    (False, False, False),
+        #    (False, False, True),
+        #), "Only wgrad GEMM override is currently supported."
diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py
index 516081c7b2..ef8672748d 100644
--- a/transformer_engine/pytorch/module.py
+++ b/transformer_engine/pytorch/module.py
@@ -906,20 +906,31 @@ def backward(
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
 
-            # DGRAD: Evaluated unconditionally to feed into Linear backward
-            dgrad = fp8_gemm(
-                weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                # DGRAD: Evaluated unconditionally to feed into Linear backward
+                dgrad = fp8_gemm(
+                    weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                # TODO: confirm `weight` is already in the activation dtype (e.g. BF16)
+                dgrad, _, _ = gemm(
+                    weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
         else:
             # DGRAD: Evaluated unconditionally to feed into Linear backward
             dgrad, _, _ = gemm(
@@ -1598,20 +1609,30 @@ def backward(
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
 
-            # DGRAD
-            dgrad = fp8_gemm(
-                weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                # DGRAD
+                dgrad = fp8_gemm(
+                    weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                dgrad, _, _ = gemm(
+                    weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
         else:
             # DGRAD
             dgrad, _, _ = gemm(
@@ -2321,21 +2342,33 @@ def backward(
             fp8_dtype_backward = get_fp8_te_dtype(
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                # FC2 DGRAD; Unconditional
+                fc2_dgrad = fp8_gemm(
+                    fc2_weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM2_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                fc2_dgrad, _, _ = gemm(
+                    fc2_weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    gelu=not ctx.bias_gelu_nvfusion,
+                    grad=True,
+                    gelu_input=fc1_out,
+                )
 
-            # FC2 DGRAD; Unconditional
-            fc2_dgrad = fp8_gemm(
-                fc2_weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM2_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
 
             # FC2 WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -2402,19 +2435,30 @@ def backward(
                 dgelu_t = None
 
             # FC1 DGRAD: Unconditional
-            fc1_dgrad = fp8_gemm(
-                fc1_weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                dgelu,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT2,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
+            if not ctx.fp8_meta["recipe"].override_linear_precision.dgrad:
+                fc1_dgrad = fp8_gemm(
+                    fc1_weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    dgelu,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT2,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                # NOTE: This will not work if wgrad is in FP8
+                fc1_dgrad, _, _ = gemm(
+                    fc1_weight,
+                    dgelu_no_fp8,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
         else:
             # FC2 DGRAD; Unconditional
             fc2_dgrad, _, _ = gemm(
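
Usage note: a minimal sketch of how the new dgrad override would be exercised from user
code, assuming the usual transformer_engine.pytorch entry points (fp8_autocast, te.Linear)
and the (fprop, dgrad, wgrad) ordering of override_linear_precision; the module sizes and
tensor shapes below are illustrative only. Per the NOTE in the LayerNormMLP hunk, the
dgrad override is expected to be used together with the wgrad override.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # (fprop, dgrad, wgrad): fprop must stay False per the assertion in recipe.py;
    # with this patch the dgrad flag is honored in addition to wgrad.
    fp8_recipe = DelayedScaling(
        fp8_format=Format.HYBRID,
        amax_history_len=16,
        amax_compute_algo="max",
        override_linear_precision=(False, True, True),
    )

    model = te.Linear(1024, 1024, bias=True).cuda()
    inp = torch.randn(512, 1024, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)

    # The dgrad GEMM now runs in the activation dtype instead of FP8.
    out.sum().backward()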