From 939d13c99199fada498bbe30da931c5b95ce089a Mon Sep 17 00:00:00 2001
From: ANMOL GUPTA
Date: Thu, 4 May 2023 11:09:28 -0700
Subject: [PATCH] custom bf16 fprop in TE

---
 transformer_engine/pytorch/module.py      | 28 +++++++++++++++++------
 transformer_engine/pytorch/transformer.py | 16 +++++++++----
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py
index 58614a44fb..9606b2283f 100644
--- a/transformer_engine/pytorch/module.py
+++ b/transformer_engine/pytorch/module.py
@@ -673,8 +673,10 @@ def forward(
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        use_bf16_fprop: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
+        #print(f'====LayerNormLinear: {use_bf16_fprop}')
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
@@ -695,7 +697,7 @@ def forward(
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
-        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -754,7 +756,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -1063,6 +1065,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )
 
 
@@ -1313,6 +1316,7 @@ def forward(
         weight: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
+        use_bf16_fprop: bool = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
@@ -1388,6 +1392,7 @@ def forward(
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
+            use_bf16_fprop,
         )
         out = fwd_fn(*args)
 
@@ -1430,8 +1435,10 @@ def forward(
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        use_bf16_fprop: bool,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
+        #print(f'++++Linear: {use_bf16_fprop}')
         in_features = weight.shape[-1]
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
@@ -1470,7 +1477,7 @@ def forward(
                     ), None
 
         # Column Parallel Linear
-        if not fp8_meta["recipe"].override_linear_precision.fprop:
+        if not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             if parallel_mode == "column" and sequence_parallel:
                 inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
             else:
@@ -1481,7 +1488,7 @@ def forward(
             else:
                 inputmat_total = inputmat_no_fp8
 
-        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -1752,6 +1759,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )
 
 
@@ -1954,6 +1962,7 @@ def forward(
         weight: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
+        use_bf16_fprop: bool = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply the linear transformation to the input.
@@ -2022,6 +2031,7 @@ def forward(
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
+            use_bf16_fprop,
         )
         out = linear_fn(*args)
 
@@ -2070,8 +2080,10 @@ def forward(
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        use_bf16_fprop: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
+        #print(f'****LayerNormMLP: {use_bf16_fprop}')
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
@@ -2092,7 +2104,7 @@ def forward(
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
 
-        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             if not return_layernorm_output:
                 if is_grad_enabled:
                     ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2144,7 +2156,7 @@ def forward(
         else:
             ln_out_total = ln_out
 
-        if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
+        if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
             bias_dtype = (
                 torch.bfloat16
                 if activation_dtype == torch.float32
@@ -2680,6 +2692,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )
 
 
@@ -2919,7 +2932,7 @@ def reset_layer_norm_parameters(self) -> None:
         init.zeros_(self.layer_norm_bias)
 
     def forward(
-        self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
+        self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, use_bf16_fprop: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a feedforward network (MLP Block).
@@ -2980,6 +2993,7 @@ def forward(
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
+            use_bf16_fprop,
         )
         out = fwd_fn(*args)
 
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index fa00fb86fc..287ded5715 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -638,6 +638,7 @@ def forward(
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
         inference_params: Optional[Any] = None,
+        use_bf16_fprop: bool = False,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         """MultiHeadAttention FWD"""
         # hidden_states: [sq, b, h]
@@ -681,6 +682,7 @@ def forward(
                 layernorm_qkv_outputs = self.layernorm_qkv(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    use_bf16_fprop=use_bf16_fprop,
                 )
                 if self.return_layernorm_output:
                     mixed_x_layer, layernorm_output = layernorm_qkv_outputs
@@ -690,6 +692,7 @@ def forward(
                 mixed_x_layer = self.qkv(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    use_bf16_fprop=use_bf16_fprop,
                 )
 
             if self.qkv_weight_interleaved:
@@ -720,6 +723,7 @@ def forward(
             mixed_kv_layer = self.key_value(
                 encoder_output,
                 is_first_microbatch=is_first_microbatch,
+                use_bf16_fprop=use_bf16_fprop,
             )
 
             if self.qkv_weight_interleaved:
@@ -749,6 +753,7 @@ def forward(
                 layernorm_query_outputs = self.layernorm_query(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    use_bf16_fprop=use_bf16_fprop,
                 )
                 if self.return_layernorm_output:
                     query_layer, layernorm_output = layernorm_query_outputs
@@ -758,6 +763,7 @@ def forward(
                 query_layer = self.query_layer(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    use_bf16_fprop=use_bf16_fprop,
                 )
 
             # [sq, b, hp] --> [sq, b, np, hn]
@@ -807,7 +813,8 @@ def forward(
         # =================
 
         attention_output, attention_bias = self.proj(
-            context_layer, is_first_microbatch=is_first_microbatch
+            context_layer, is_first_microbatch=is_first_microbatch,
+            use_bf16_fprop=use_bf16_fprop,
         )
 
         if self.input_layernorm and self.return_layernorm_output:
@@ -1118,6 +1125,7 @@ def forward(
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
         inference_params: Optional[Any] = None,
+        use_bf16_fprop: bool = False,
     ) -> torch.Tensor:
         """
         Transformer Layer: attention block and a feedforward network (MLP)
@@ -1158,9 +1166,7 @@ def forward(
                                    otherwise be occupied to store the forward activations until backprop.
         """
 
-        hidden_states = hidden_states.contiguous()
-
         if self.self_attn_mask_type != "causal" and attention_mask is not None:
             assert (
                 attention_mask.dtype == torch.bool
@@ -1179,6 +1185,7 @@ def forward(
             inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
+            use_bf16_fprop=use_bf16_fprop,
         )
         if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
             attention_output, attention_bias, residual = self_attention_outputs
@@ -1217,6 +1224,7 @@ def forward(
                 encoder_output=encoder_output,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
+                use_bf16_fprop=use_bf16_fprop,
             )
             if self.apply_residual_connection_post_layernorm:
                 attention_output, attention_bias, residual = inter_attention_outputs
@@ -1231,7 +1239,7 @@ def forward(
 
         # MLP.
         mlp_outputs = self.layernorm_mlp(
-            bda_output, is_first_microbatch=is_first_microbatch
+            bda_output, is_first_microbatch=is_first_microbatch, use_bf16_fprop=use_bf16_fprop,
        )
         if self.apply_residual_connection_post_layernorm:
             mlp_output, mlp_bias, residual = mlp_outputs
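
Usage note (not part of the patch): a minimal sketch of how the new per-call flag might be exercised once the patch is applied, assuming an FP8-capable GPU and the standard Transformer Engine APIs (te.TransformerLayer, te.fp8_autocast, recipe.DelayedScaling); the layer sizes, recipe settings, and tensor shapes below are illustrative only.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Illustrative single TE transformer layer; input layout is [seq, batch, hidden].
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=torch.bfloat16,
    ).cuda()
    inp = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)

    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # With the patch applied, use_bf16_fprop=True skips the FP8 casts guarded by
        # override_linear_precision.fprop, so the forward GEMMs run in bf16; the
        # backward path is not changed by this patch.
        out = layer(inp, attention_mask=None, use_bf16_fprop=True)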