diff --git a/transformer_engine/common/recipe.py b/transformer_engine/common/recipe.py
index 3bb5320475..486e4a19e2 100644
--- a/transformer_engine/common/recipe.py
+++ b/transformer_engine/common/recipe.py
@@ -133,7 +133,7 @@ def scaling_factor_compute(amax: Tensor,
 
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        assert self.override_linear_precision in (
-            (False, False, False),
-            (False, False, True),
-        ), "Only wgrad GEMM override is currently supported."
+        #assert self.override_linear_precision in (
+        #    (False, False, False),
+        #    (False, False, True),
+        #), "Only wgrad GEMM override is currently supported."
diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py
index 516081c7b2..58614a44fb 100644
--- a/transformer_engine/pytorch/module.py
+++ b/transformer_engine/pytorch/module.py
@@ -695,6 +695,7 @@ def forward(
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -753,7 +754,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -816,6 +817,32 @@ def forward(
                 use_bias=use_bias,
             )
 
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=weight_fp8,
+                            transpose_out=weight_t_fp8,
+                        )
+                    else:
+                        weight_t_fp8 = None
+                        weight_fp8 = cast_to_fp8(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward)
+
+                ln_out = cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
@@ -1413,11 +1440,9 @@ def forward(
         ), "Input and weight dimensions are not compatible for FP8 execution."
 
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
-
         # Cast for native AMP
         inputmat = cast_if_needed(inputmat, activation_dtype)
         inputmat_no_fp8 = inputmat
-
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
@@ -1445,12 +1470,18 @@ def forward(
             ), None
 
         # Column Parallel Linear
-        if parallel_mode == "column" and sequence_parallel:
-            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
-        else:
-            inputmat_total = inputmat
+        if not fp8_meta["recipe"].override_linear_precision.fprop:
+            if parallel_mode == "column" and sequence_parallel:
+                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
+            else:
+                inputmat_total = inputmat
+        else:
+            if parallel_mode == "column" and sequence_parallel:
+                inputmat_total, _ = gather_along_first_dim(inputmat_no_fp8, tp_group)
+            else:
+                inputmat_total = inputmat_no_fp8
 
-        if fp8:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -1514,6 +1545,26 @@ def forward(
                 use_bias=use_bias,
             )
 
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=weight_fp8,
+                            transpose_out=weight_t_fp8,
+                        )
+                    else:
+                        weight_t_fp8 = None
+                        weight_fp8 = cast_to_fp8(
+                            weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                        )
+
         if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
             ctx.save_for_backward(
@@ -2040,6 +2091,8 @@ def forward(
         # of an extra fp8 cast.
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2091,7 +2144,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8:
+        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -2222,6 +2275,57 @@ def forward(
                 bias=fc2_bias,
                 use_bias=use_bias,
             )
+
+            if fp8:
+                if update_fp8_weights:
+                    if is_grad_enabled:
+                        fp8_cast_transpose_fused(
+                            fc1_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=fc1_weight_fp8,
+                            transpose_out=fc1_weight_t_fp8,
+                        )
+
+                        fp8_cast_transpose_fused(
+                            fc2_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM2_WEIGHT,
+                            fp8_dtype_forward,
+                            cast_out=fc2_weight_fp8,
+                            transpose_out=fc2_weight_t_fp8,
+                        )
+                    else:
+                        fc1_weight_t_fp8 = None
+                        fc1_weight_fp8 = cast_to_fp8(
+                            fc1_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                        )
+                        fc2_weight_t_fp8 = None
+                        fc2_weight_fp8 = cast_to_fp8(
+                            fc2_weight,
+                            fp8_meta["scaling_fwd"],
+                            tex.FP8FwdTensors.GEMM2_WEIGHT,
+                            fp8_dtype_forward,
+                        )
+
+                ln_out = cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+
+                gelu_out = cast_to_fp8(
+                    gelu_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM2_INPUT,
+                    fp8_dtype_forward,
+                )
+
         if is_grad_enabled:
             ctx.save_for_backward(
                 inputmat,
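
Below is a minimal usage sketch, not part of the patch, showing how the relaxed recipe check could be exercised from user code. It assumes a CUDA device with FP8 support and a TransformerEngine build around this revision; the layer and tensor shapes are arbitrary. The override_linear_precision tuple is ordered (fprop, dgrad, wgrad), so (True, False, False) requests a higher-precision forward GEMM while dgrad and wgrad stay in FP8, which is the path the module.py changes add.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# (fprop, dgrad, wgrad): ask for the forward GEMM in higher precision only.
# This combination was previously rejected by the __post_init__ assertion
# that the patch disables. Depending on the TE version, the value may need
# to be the recipe's override_linear_precision namedtuple rather than a
# plain tuple.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    override_linear_precision=(True, False, False),
)

layer = te.Linear(1024, 1024).cuda()   # arbitrary example layer and shapes
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)                   # forward GEMM runs in higher precision
out.sum().backward()                   # dgrad/wgrad GEMMs still use FP8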