28 changes: 21 additions & 7 deletions transformer_engine/pytorch/module.py
@@ -673,8 +673,10 @@ def forward(
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
use_bf16_fprop: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
#print(f'====LayerNormLinear: {use_bf16_fprop}')
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
@@ -695,7 +697,7 @@ def forward(
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -754,7 +756,7 @@ def forward(
else:
ln_out_total = ln_out

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -1063,6 +1065,7 @@ def backward(
None,
None,
None,
None,
)


@@ -1313,6 +1316,7 @@ def forward(
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
use_bf16_fprop: bool = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
@@ -1388,6 +1392,7 @@ def forward(
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
use_bf16_fprop,
)
out = fwd_fn(*args)

@@ -1430,8 +1435,10 @@ def forward(
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
use_bf16_fprop: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
#print(f'++++Linear: {use_bf16_fprop}')
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
@@ -1470,7 +1477,7 @@ def forward(
), None

# Column Parallel Linear
if not fp8_meta["recipe"].override_linear_precision.fprop:
if not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
@@ -1481,7 +1488,7 @@ def forward(
else:
inputmat_total = inputmat_no_fp8

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -1752,6 +1759,7 @@ def backward(
None,
None,
None,
None,
)


@@ -1954,6 +1962,7 @@ def forward(
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
use_bf16_fprop: bool = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
@@ -2022,6 +2031,7 @@ def forward(
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
use_bf16_fprop,
)
out = linear_fn(*args)

@@ -2070,8 +2080,10 @@ def forward(
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
use_bf16_fprop: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
#print(f'****LayerNormMLP: {use_bf16_fprop}')
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
@@ -2092,7 +2104,7 @@ def forward(
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
if not return_layernorm_output:
if is_grad_enabled:
ln_out, mu, rsigma = layernorm_fwd_fp8(
@@ -2144,7 +2156,7 @@ def forward(
else:
ln_out_total = ln_out

if fp8 and not fp8_meta["recipe"].override_linear_precision.fprop:
if fp8 and not (fp8_meta["recipe"].override_linear_precision.fprop or use_bf16_fprop):
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
@@ -2680,6 +2692,7 @@ def backward(
None,
None,
None,
None,
)


@@ -2919,7 +2932,7 @@ def reset_layer_norm_parameters(self) -> None:
init.zeros_(self.layer_norm_bias)

def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, use_bf16_fprop: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
@@ -2980,6 +2993,7 @@ def forward(
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
use_bf16_fprop,
)
out = fwd_fn(*args)

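Note on the module.py hunks above: they all apply the same precision gate, once in LayerNormLinear, once in Linear, and once in LayerNormMLP. As a sketch only (the helper below is hypothetical; the patch inlines the expression rather than adding a function), the decision reads:

def use_fp8_fprop_gemm(fp8: bool, recipe, use_bf16_fprop: bool) -> bool:
    # The FP8 cast + FP8 GEMM forward path is taken only when FP8 is active
    # and neither the recipe's existing override nor the new per-call flag
    # requests a higher-precision forward GEMM.
    return fp8 and not (recipe.override_linear_precision.fprop or use_bf16_fprop)

When the gate evaluates False, the existing non-FP8 branch is taken and the forward GEMM runs on unquantized inputs in the activation dtype (BF16 being the intended case, going by the flag name); the patch only changes these forward-path gates.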
16 changes: 12 additions & 4 deletions transformer_engine/pytorch/transformer.py
@@ -638,6 +638,7 @@ def forward(
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
use_bf16_fprop: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
@@ -681,6 +682,7 @@ def forward(
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
@@ -690,6 +692,7 @@ def forward(
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)

if self.qkv_weight_interleaved:
@@ -720,6 +723,7 @@ def forward(
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)

if self.qkv_weight_interleaved:
@@ -749,6 +753,7 @@ def forward(
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
@@ -758,6 +763,7 @@ def forward(
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)

# [sq, b, hp] --> [sq, b, np, hn]
@@ -807,7 +813,8 @@ def forward(
# =================

attention_output, attention_bias = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
context_layer, is_first_microbatch=is_first_microbatch,
use_bf16_fprop=use_bf16_fprop,
)

if self.input_layernorm and self.return_layernorm_output:
@@ -1118,6 +1125,7 @@ def forward(
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
use_bf16_fprop: bool = False,
) -> torch.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
@@ -1158,9 +1166,7 @@ def forward(
otherwise be occupied to store the forward activations until
backprop.
"""

hidden_states = hidden_states.contiguous()

if self.self_attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
@@ -1179,6 +1185,7 @@ def forward(
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
use_bf16_fprop=use_bf16_fprop,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, attention_bias, residual = self_attention_outputs
@@ -1217,6 +1224,7 @@ def forward(
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
use_bf16_fprop=use_bf16_fprop,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs
@@ -1231,7 +1239,7 @@ def forward(

# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output, is_first_microbatch=is_first_microbatch
bda_output, is_first_microbatch=is_first_microbatch, use_bf16_fprop=use_bf16_fprop,
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
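For context, a usage sketch of the new argument end to end. The constructor arguments, params_dtype, and the fp8_autocast/DelayedScaling calls below are assumptions based on the public Transformer Engine interface, not taken from this diff; only hidden_states shape ([sq, b, h]), attention_mask, and use_bf16_fprop come from the patch:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Constructor arguments are assumed for illustration only.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
).cuda()
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")  # [sq, b, h]

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    y_fp8 = layer(x, attention_mask=None)                        # default path: FP8 forward GEMMs
    y_bf16 = layer(x, attention_mask=None, use_bf16_fprop=True)  # new flag: forward GEMMs stay in BF16

The flag is forwarded from TransformerLayer.forward through MultiHeadAttention into every LayerNormLinear, Linear, and LayerNormMLP call, so a single keyword argument switches all FP8-eligible forward GEMMs of the layer for that call.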