From b88d2ec9619660f3035b10475227f9cd00e1c0ac Mon Sep 17 00:00:00 2001
From: ANMOL GUPTA
Date: Tue, 25 Apr 2023 21:40:57 -0700
Subject: [PATCH 1/2] fprop bf16

---
 transformer_engine/common/recipe.py  |  8 +--
 transformer_engine/pytorch/module.py | 97 +++++++++++++++++++++++++---
 2 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/transformer_engine/common/recipe.py b/transformer_engine/common/recipe.py
index 3bb5320475..486e4a19e2 100644
--- a/transformer_engine/common/recipe.py
+++ b/transformer_engine/common/recipe.py
@@ -133,7 +133,7 @@ def scaling_factor_compute(amax: Tensor,
 
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        assert self.override_linear_precision in (
-            (False, False, False),
-            (False, False, True),
-        ), "Only wgrad GEMM override is currently supported."
+        #assert self.override_linear_precision in (
+        #    (False, False, False),
+        #    (False, False, True),
+        #), "Only wgrad GEMM override is currently supported."
diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py
index 516081c7b2..66eea765c4 100644
--- a/transformer_engine/pytorch/module.py
+++ b/transformer_engine/pytorch/module.py
@@ -692,7 +692,7 @@ def forward(
         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
         # of an extra fp8 cast.
-        if fp8:
+        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
             if not return_layernorm_output:
@@ -753,7 +753,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8:
+        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -815,6 +815,25 @@ def forward(
                 bias=bias,
                 use_bias=use_bias,
             )
+
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=weight_fp8,
+                            transpose_out=weight_t_fp8,
+                        )
+                    else:
+                        weight_t_fp8 = None
+                        weight_fp8 = cast_to_fp8(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward)
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
@@ -1445,12 +1464,18 @@ def forward(
             ), None
 
         # Column Parallel Linear
-        if parallel_mode == "column" and sequence_parallel:
-            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
+        if not fp8_meta["recipe"].override_linear_precision.fprop:
+            if parallel_mode == "column" and sequence_parallel:
+                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
+            else:
+                inputmat_total = inputmat
         else:
-            inputmat_total = inputmat
+            if parallel_mode == "column" and sequence_parallel:
+                inputmat_total, _ = gather_along_first_dim(inputmat_no_fp8, tp_group)
+            else:
+                inputmat_total = inputmat_no_fp8
 
-        if fp8:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -1513,6 +1538,26 @@ def forward(
                 bias=bias,
                 use_bias=use_bias,
             )
+
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=weight_fp8,
+                            transpose_out=weight_t_fp8,
+                        )
+                    else:
+                        weight_t_fp8 = None
+                        weight_fp8 = cast_to_fp8(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                        )
         if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
@@ -2038,7 +2083,7 @@ def forward(
         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
-        if fp8:
+        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2091,7 +2136,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8:
+        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -2222,6 +2267,42 @@ def forward(
                 bias=fc2_bias,
                 use_bias=use_bias,
             )
+
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            fc1_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=fc1_weight_fp8,
+                            transpose_out=fc1_weight_t_fp8,
+                        )
+
+                        fp8_cast_transpose_fused(
+                            fc2_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM2_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=fc2_weight_fp8,
+                            transpose_out=fc2_weight_t_fp8,
+                        )
+                    else:
+                        fc1_weight_t_fp8 = None
+                        fc1_weight_fp8 = cast_to_fp8(
+                            fc1_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                        )
+                        fc2_weight_t_fp8 = None
+                        fc2_weight_fp8 = cast_to_fp8(
+                            fc2_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM2_WEIGHT,
+                            fp8_dtype_forward,
+                        )
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,

From 41f63406bde3cf07ef7b107f2e72dd47f983670b Mon Sep 17 00:00:00 2001
From: ANMOL GUPTA
Date: Wed, 26 Apr 2023 15:20:36 -0700
Subject: [PATCH 2/2] fixes for fprop_bf16

---
 transformer_engine/pytorch/module.py | 39 ++++++++++++++++++++++------
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py
index 66eea765c4..58614a44fb 100644
--- a/transformer_engine/pytorch/module.py
+++ b/transformer_engine/pytorch/module.py
@@ -692,9 +692,10 @@ def forward(
         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
         # of an extra fp8 cast.
-        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -753,7 +754,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -815,7 +816,7 @@ def forward(
                 bias=bias,
                 use_bias=use_bias,
             )
-
+
             if fp8:
                 if update_fp8_weights:
                     if is_grad_enabled:
@@ -835,6 +836,13 @@ def forward(
                             tex.FP8FwdTensors.GEMM1_WEIGHT,
                             fp8_dtype_forward)
 
+                ln_out = cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
@@ -1432,11 +1440,9 @@ def forward(
             ), "Input and weight dimensions are not compatible for FP8 execution."
 
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
-
         # Cast for native AMP
         inputmat = cast_if_needed(inputmat, activation_dtype)
         inputmat_no_fp8 = inputmat
-
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -1538,7 +1544,7 @@ def forward(
                 bias=bias,
                 use_bias=use_bias,
             )
-
+
             if fp8:
                 if update_fp8_weights:
                     if is_grad_enabled:
@@ -2083,8 +2089,10 @@ def forward(
         # If residual connection is after LN, we need `ln_out`
         # tensor in higher precision, this comes at the cost
         # of an extra fp8 cast.
-        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2136,7 +2144,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -2303,6 +2311,21 @@ def forward(
                         tex.FP8FwdTensors.GEMM2_WEIGHT,
                         fp8_dtype_forward,
                     )
+
+                ln_out = cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+
+                gelu_out = cast_to_fp8(
+                    gelu_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM2_INPUT,
+                    fp8_dtype_forward,
+                )
+
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
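Note (not part of the patch series): a minimal, hypothetical usage sketch of what these patches enable. With the recipe assertion commented out, the `fprop` position of `override_linear_precision` can now be set to `True`, so the forward GEMM runs in the activation precision (e.g. bf16) while weights and activations are still cast to FP8 for the backward GEMMs. It assumes Transformer Engine's public PyTorch API (`fp8_autocast`, `DelayedScaling`, `te.Linear`); layer sizes, recipe settings, and the plain-tuple form of `override_linear_precision` are illustrative and may need adjusting for a given TE version.

```python
# Hypothetical usage sketch of the fprop override un-gated by this series.
# Assumes Transformer Engine's public PyTorch API; sizes are illustrative.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
    # (fprop, dgrad, wgrad): True means "run that GEMM in higher precision".
    # Before this patch only (False, False, True) was accepted; the patch
    # lifts the assertion so the fprop GEMM can be overridden as well.
    override_linear_precision=(True, False, False),
)

layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)    # forward GEMM in bf16; weights cast to FP8 for backward
out.sum().backward()    # dgrad/wgrad GEMMs still follow the FP8 recipe
```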