From 40649698d143f3b40b070581b31d2db0ac3d70d0 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Mon, 8 Aug 2022 18:26:32 -0400 Subject: [PATCH 1/8] add multi-query attention logic in attention module --- megatron/arguments.py | 4 ++ megatron/model/transformer.py | 87 +++++++++++++++++++++++++++++------ 2 files changed, 76 insertions(+), 15 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 3261398..66663aa 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -383,6 +383,10 @@ def _add_network_size_args(parser): 'attention. This is set to ' ' args.hidden_size // args.num_attention_heads ' 'if not provided.') + group.add_argument('--attention-head-type', type=str, default='multihead', + choices=['multihead', 'multiquery'], + help='Type of attention heads. `multihead` is the standard multi-head attention.' + '`multiquery` shares the values and keys across attention heads') group.add_argument('--max-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index b9c1b79..adb0b84 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -26,7 +26,7 @@ from megatron.model import LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl -from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_linear_layer """ We use the following notation throughout this file: @@ -214,11 +214,12 @@ def __init__(self, layer_number, self.attention_dropout = torch.nn.Dropout(args.attention_dropout) def forward(self, query_layer, key_layer, - value_layer, attention_mask): + value_layer, attention_mask, expand_key_value=False): # =================================== # Raw attention scores. [b, np, s, s] # =================================== + np = query_layer.size(2) # [b, np, sq, sk] output_size = (query_layer.size(1), @@ -229,9 +230,15 @@ def forward(self, query_layer, key_layer, # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, 1, hn] -> [sk, b * np, hn] + # TODO: Check that we indeed get the speedup at inference. Isn't the reshape memory allocation a bottleneck? 
+ if expand_key_value: + key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) + key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], - output_size[0] * output_size[1], -1) + else: + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = get_global_memory_buffer().get_tensor( @@ -274,13 +281,18 @@ def forward(self, query_layer, key_layer, # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), - value_layer.size(2), + np, query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), - output_size[0] * output_size[1], -1) + # [sk, b, 1, hn] -> [sk, b * np, hn] + if expand_key_value: + value_layer = value_layer.expand(value_layer.size(0), value_layer.size(1), np, -1) + value_layer = value_layer.reshape(value_layer.size(0), value_layer.size(1) * np, -1) + else: + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], @@ -320,6 +332,7 @@ def __init__(self, init_method, self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype + self.attention_head_type = args.attention_head_type projection_size = args.kv_channels * args.num_attention_heads @@ -331,12 +344,28 @@ def __init__(self, init_method, args.num_attention_heads, world_size) # Strided linear layer. - if attention_type == AttnType.self_attn: + if attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': self.query_key_value = mpu.ColumnParallelLinear( args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method) + elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery': + self.query = mpu.ColumnParallelLinear( + args.hidden_size, + projection_size, + gather_output=False, + init_method=init_method) + # In MultiQuery attention, keys and values are shared across heads + # Use args.kv_channels instead of projection_size + # No `.fork()` so the rng tracker is shared across tensor-parallel processes. + # with mpu.get_cuda_rng_tracker(): + self.key_value = get_linear_layer( + args.hidden_size, + 2 * args.kv_channels, + init_method=init_method) + print(f"KV WEIGHT {layer_number}", self.key_value.weight) + # TODO: add elif block for cross_attn and multiquery? 
else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( @@ -364,7 +393,7 @@ def __init__(self, init_method, skip_bias_add=True) def _checkpointed_attention_forward(self, query_layer, key_layer, - value_layer, attention_mask): + value_layer, attention_mask, expand_key_value): """Forward method with activation checkpointing.""" def custom_forward(*inputs): query_layer = inputs[0] @@ -372,7 +401,7 @@ def custom_forward(*inputs): value_layer = inputs[2] attention_mask = inputs[3] output_ = self.core_attention(query_layer, key_layer, - value_layer, attention_mask) + value_layer, attention_mask, expand_key_value) return output_ hidden_states = mpu.checkpoint( @@ -385,11 +414,12 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, - self.num_attention_heads_per_partition, + self.num_attention_heads_per_partition if self.attention_head_type == "multihead" else 1, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) + def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None): # hidden_states: [sq, b, h] @@ -415,7 +445,7 @@ def forward(self, hidden_states, attention_mask, # Query, Key, and Value # ===================== - if self.attention_type == AttnType.self_attn: + if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) @@ -429,6 +459,33 @@ def forward(self, hidden_states, attention_mask, (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery': + # Attention heads [sq, b, h] --> [sq, b, (2 * hn)] + mixed_kv_layer = self.key_value(hidden_states) + + # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn] + # new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + # (self.num_attention_heads_per_partition, + # 2 * self.hidden_size_per_attention_head) + # mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape) + + # [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (1, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn] + (key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, np * hn] + query_layer, _ = self.query(hidden_states) + # [sq, b, np * hn] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) @@ -478,10 +535,10 @@ def forward(self, hidden_states, attention_mask, if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask) + query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) else: context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask) + query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) # ================= # Output. 
[sq, b, h] From 190e328617b629c10404822a261835e6e99f7f16 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Tue, 9 Aug 2022 11:56:00 -0400 Subject: [PATCH 2/8] add kv weight gradient reduction in tensor-parallel group --- megatron/mpu/layers.py | 3 ++- megatron/optimizer/optimizer.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 3ee9db2..ac78c3a 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -264,7 +264,8 @@ def backward(ctx, grad_output): handle.wait() # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], + # TODO: Is the reshape preventing us from getting a speedup here? + grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index b265145..6e83e65 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -265,6 +265,19 @@ def allreduce_embedding_grads(self, args): """All-reduce both word and position embeddings.""" self.allreduce_word_embedding_grads(args) self.allreduce_position_embedding_grads(args) + + def allreduce_key_value_grads(self, args): + # TODO: models[0] ? + unwrapped_model = self.models[0] + unwrapped_model = unwrap_model( + unwrapped_model, (torchDDP, LocalDDP, Float16Module)) + for layer in unwrapped_model.language_model.encoder.layers: + kv_weight = layer.self_attention.key_value.weight + if args.DDP_impl == 'local': + grad = kv_weight.main_grad + else: + grad = kv_weight.grad + torch.distributed.all_reduce(grad, group=mpu.get_tensor_model_parallel_group()) def allreduce_layernorm_grads(self, args): @@ -310,6 +323,13 @@ def reduce_model_grads(self, args, timers): self.allreduce_embedding_grads(args) timers('backward-embedding-all-reduce').stop() + # All-reduce key-value grads if needed. + if args.attention_head_type == "multiquery": + timers('backward-key-value-all-reduce').start() + self.allreduce_key_value_grads(args) + timers('backward-key-value-all-reduce').stop() + + class MixedPrecisionOptimizer(MegatronOptimizer): """Base class for both the float-16 and the distributed optimizer. From 6fd0c29dfe204bfa52f5c37602de954dae4e521b Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Wed, 10 Aug 2022 12:46:56 -0400 Subject: [PATCH 3/8] more efficient multiquery attention --- megatron/model/transformer.py | 168 ++++++++++++++++++++++++++++++++-- 1 file changed, 161 insertions(+), 7 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index adb0b84..aa776b2 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -215,7 +215,7 @@ def __init__(self, layer_number, def forward(self, query_layer, key_layer, value_layer, attention_mask, expand_key_value=False): - + timers = get_timers() # =================================== # Raw attention scores. [b, np, s, s] # =================================== @@ -230,27 +230,30 @@ def forward(self, query_layer, key_layer, # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, 1, hn] -> [sk, b * np, hn] - # TODO: Check that we indeed get the speedup at inference. Isn't the reshape memory allocation a bottleneck? 
+ timers("CoreAttention: K view/reshape").start() if expand_key_value: + # [sk, b, 1, hn] -> [sk, b * np, hn] key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) - # [sk, b, np, hn] -> [sk, b * np, hn] else: + # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + timers("CoreAttention: K view/reshape").stop() # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = get_global_memory_buffer().get_tensor( (output_size[0]*output_size[1], output_size[2], output_size[3]), query_layer.dtype, "mpu") + timers("CoreAttention: QK matmul").start() # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0/self.norm_factor)) + timers("CoreAttention: QK matmul").stop() # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -259,6 +262,7 @@ def forward(self, query_layer, key_layer, # Attention probs and dropout # =========================== + timers("CoreAttention: Softmax, dropout").start() # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) @@ -271,6 +275,7 @@ def forward(self, query_layer, key_layer, attention_probs = self.attention_dropout(attention_probs) else: attention_probs = self.attention_dropout(attention_probs) + timers("CoreAttention: Softmax, dropout").stop() # ========================= # Context layer. [sq, b, hp] @@ -285,6 +290,7 @@ def forward(self, query_layer, key_layer, query_layer.size(0), value_layer.size(3)) + timers("CoreAttention: V view/reshape").start() # [sk, b, 1, hn] -> [sk, b * np, hn] if expand_key_value: value_layer = value_layer.expand(value_layer.size(0), value_layer.size(1), np, -1) @@ -293,19 +299,137 @@ def forward(self, query_layer, key_layer, # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + timers("CoreAttention: V view/reshape").stop() # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + timers("CoreAttention: V matmul").start() # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + timers("CoreAttention: V matmul").stop() # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) + timers("CoreAttention: context contiguous").start() + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + timers("CoreAttention: context contiguous").stop() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class MultiQueryCoreAttention(CoreAttention): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, query_layer, key_layer, value_layer, attention_mask, expand_key_value=False): + timers = get_timers() + # =================================== + # Raw attention scores. 
[b, np, s, s] + # =================================== + sq = query_layer.size(0) + bs = query_layer.size(1) + np = query_layer.size(2) + + sk = key_layer.size(0) + # Only one head for key and values + assert key_layer.size(2) == 1 and value_layer.size(2) == 1 + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + + timers("CoreAttention: K view/reshape").start() + # [sq, b, np, hn] -> [b, np * sq, hn] + query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1) + # [sk, b, 1, hn] -> [b, hn, sk] + key_layer = key_layer.squeeze(2).permute(1, 2, 0) + # [sk, b, 1, hn] -> [sk, b * np, hn] + # key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) + # key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) + + # preallocting input tensor: [b, np * sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (bs, np * sq, sk), + query_layer.dtype, "mpu") + timers("CoreAttention: K view/reshape").stop() + + timers("CoreAttention: QK matmul").start() + # Raw attention scores. [b, np * sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, # [b, np * sq, hn] + key_layer, # [b, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + timers("CoreAttention: QK matmul").stop() + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(bs, np, sq, sk) + + # =========================== + # Attention probs and dropout + # =========================== + + timers("CoreAttention: Softmax, dropout").start() + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + timers("CoreAttention: Softmax, dropout").stop() + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. 
+ # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + np, + query_layer.size(0), + value_layer.size(3)) + + timers("CoreAttention: V view/reshape").start() + # [sk, b, 1, hn] -> [b, sk, hn] + value_layer = value_layer.squeeze(2).transpose(0, 1) + timers("CoreAttention: V view/reshape").stop() + + # change view [b, np * sq, sk] + attention_probs = attention_probs.view(bs, np * sq, -1) + + timers("CoreAttention: V matmul").start() + # matmul: [b, np * sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + timers("CoreAttention: V matmul").stop() + + # change view [b, np, sq, hn] + context_layer = context_layer.view(bs, np, sq, -1) + + timers("CoreAttention: context contiguous").start() # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + timers("CoreAttention: context contiguous").stop() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ @@ -380,8 +504,11 @@ def __init__(self, init_method, gather_output=False, init_method=init_method) - self.core_attention = CoreAttention(self.layer_number, - self.attn_mask_type) + if self.attention_head_type == 'multihead': + self.core_attention = CoreAttention(self.layer_number, + self.attn_mask_type) + else: + self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type) self.checkpoint_core_attention = args.recompute_granularity == 'selective' # Output. @@ -423,10 +550,11 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size): def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None): # hidden_states: [sq, b, h] - + timers = get_timers() # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= + timers("inference_params init").start() if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len @@ -440,11 +568,13 @@ def forward(self, hidden_states, attention_mask, else: inference_key_memory, inference_value_memory = \ inference_params.key_value_memory_dict[self.layer_number] + timers("inference_params init").stop() # ===================== # Query, Key, and Value # ===================== + timers("KV forward").start() if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) @@ -486,6 +616,8 @@ def forward(self, hidden_states, attention_mask, (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) + + # [sq, b, np, hn] -> [b, np * sq, hn] else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) @@ -508,10 +640,15 @@ def forward(self, hidden_states, attention_mask, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) + timers("KV forward").stop() + # ================================== # Adjust key and value for inference # ================================== + + timers("Inference params").start() + if inference_params: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) @@ -528,23 +665,30 @@ def forward(self, hidden_states, attention_mask, :sequence_end, batch_start:batch_end, ...] 
value_layer = inference_value_memory[ :sequence_end, batch_start:batch_end, ...] + + timers("Inference params").stop() # ================================== # core attention computation # ================================== + timers("Core attention forward").start() + if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) else: context_layer = self.core_attention( query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) + timers("Core attention forward").stop() # ================= # Output. [sq, b, h] # ================= + timers("dense").start() output, bias = self.dense(context_layer) + timers("dense").stop() return output, bias @@ -655,16 +799,19 @@ def __init__(self, init_method, output_layer_init_method, def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None): + timers = get_timers() # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. + timers("attention forward").start() attention_output, attention_bias = \ self.self_attention( layernorm_output, attention_mask, inference_params=inference_params) + timers("attention forward").stop() # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -722,7 +869,9 @@ def forward(self, hidden_states, attention_mask, layernorm_output = self.post_inter_attention_layernorm(layernorm_input) # MLP. + timers("MLP forward").start() mlp_output, mlp_bias = self.mlp(layernorm_output) + timers("MLP forward").stop() # Second residual connection. if self.apply_residual_connection_post_layernorm: @@ -946,6 +1095,9 @@ def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None): # hidden_states: [s, b, h] + timers = get_timers() + + timers("Transformer forward").start() # Checks. if inference_params: @@ -1003,4 +1155,6 @@ def forward(self, hidden_states, attention_mask, if self.post_process and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) + timers("Transformer forward").stop() + return hidden_states From 254ff4b826a89d0b22610dea20471d7a0d92fd8c Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 2 Sep 2022 11:04:41 -0400 Subject: [PATCH 4/8] raise if trying to uyse multi-query cross-atteention --- megatron/model/transformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index aa776b2..07e83e3 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -475,6 +475,7 @@ def __init__(self, init_method, gather_output=False, init_method=init_method) elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery': + # TODO: Find a way to merge the query and key-value computations? self.query = mpu.ColumnParallelLinear( args.hidden_size, projection_size, @@ -488,9 +489,7 @@ def __init__(self, init_method, args.hidden_size, 2 * args.kv_channels, init_method=init_method) - print(f"KV WEIGHT {layer_number}", self.key_value.weight) - # TODO: add elif block for cross_attn and multiquery? 
- else: + elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead': assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( args.hidden_size, @@ -503,6 +502,8 @@ def __init__(self, init_method, 2 * projection_size, gather_output=False, init_method=init_method) + else: + raise NotImplementedError("Multiquery attention not implemented for cross-attention.") if self.attention_head_type == 'multihead': self.core_attention = CoreAttention(self.layer_number, From eaf617466219fc9354ed61cfb0ca6a7d18cce4b9 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 2 Sep 2022 11:18:11 -0400 Subject: [PATCH 5/8] remove expand_key_value parameter since CoreAttention for multi-query is now in a separate class --- megatron/model/transformer.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 07e83e3..bf43d00 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -214,7 +214,7 @@ def __init__(self, layer_number, self.attention_dropout = torch.nn.Dropout(args.attention_dropout) def forward(self, query_layer, key_layer, - value_layer, attention_mask, expand_key_value=False): + value_layer, attention_mask): timers = get_timers() # =================================== # Raw attention scores. [b, np, s, s] @@ -231,14 +231,9 @@ def forward(self, query_layer, key_layer, query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) timers("CoreAttention: K view/reshape").start() - if expand_key_value: - # [sk, b, 1, hn] -> [sk, b * np, hn] - key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) - key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) - else: - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], - output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) timers("CoreAttention: K view/reshape").stop() # preallocting input tensor: [b * np, sq, sk] @@ -291,14 +286,9 @@ def forward(self, query_layer, key_layer, value_layer.size(3)) timers("CoreAttention: V view/reshape").start() - # [sk, b, 1, hn] -> [sk, b * np, hn] - if expand_key_value: - value_layer = value_layer.expand(value_layer.size(0), value_layer.size(1), np, -1) - value_layer = value_layer.reshape(value_layer.size(0), value_layer.size(1) * np, -1) - else: - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), - output_size[0] * output_size[1], -1) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) timers("CoreAttention: V view/reshape").stop() # change view [b * np, sq, sk] @@ -331,7 +321,7 @@ class MultiQueryCoreAttention(CoreAttention): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def forward(self, query_layer, key_layer, value_layer, attention_mask, expand_key_value=False): + def forward(self, query_layer, key_layer, value_layer, attention_mask): timers = get_timers() # =================================== # Raw attention scores. 
[b, np, s, s] @@ -521,7 +511,7 @@ def __init__(self, init_method, skip_bias_add=True) def _checkpointed_attention_forward(self, query_layer, key_layer, - value_layer, attention_mask, expand_key_value): + value_layer, attention_mask): """Forward method with activation checkpointing.""" def custom_forward(*inputs): query_layer = inputs[0] @@ -529,7 +519,7 @@ def custom_forward(*inputs): value_layer = inputs[2] attention_mask = inputs[3] output_ = self.core_attention(query_layer, key_layer, - value_layer, attention_mask, expand_key_value) + value_layer, attention_mask) return output_ hidden_states = mpu.checkpoint( @@ -677,10 +667,10 @@ def forward(self, hidden_states, attention_mask, if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) + query_layer, key_layer, value_layer, attention_mask) else: context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask, expand_key_value=True) + query_layer, key_layer, value_layer, attention_mask) timers("Core attention forward").stop() # ================= From 15131378f67b8e27af3df6b20df62f1dbe78fdd0 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 2 Sep 2022 13:02:09 -0400 Subject: [PATCH 6/8] remove most timers --- megatron/arguments.py | 3 + megatron/model/transformer.py | 53 +--------- tools/text_generation_benchmark.py | 163 +++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 50 deletions(-) create mode 100644 tools/text_generation_benchmark.py diff --git a/megatron/arguments.py b/megatron/arguments.py index 66663aa..c5f4e42 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -454,6 +454,9 @@ def _add_logging_args(parser): help="Name of wandb entity for reporting") group.add_argument('--wandb-project-name', type=str, default=None, help="Name of wandb project") + group.add_argument('--transformer-timers', action='store_true', + help="If set, activate the timers within the transformer layers." + "Only for debugging, as this slows down the model.") return parser diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index bf43d00..9bcfb56 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -215,7 +215,6 @@ def __init__(self, layer_number, def forward(self, query_layer, key_layer, value_layer, attention_mask): - timers = get_timers() # =================================== # Raw attention scores. [b, np, s, s] # =================================== @@ -230,25 +229,21 @@ def forward(self, query_layer, key_layer, # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - timers("CoreAttention: K view/reshape").start() # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - timers("CoreAttention: K view/reshape").stop() # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = get_global_memory_buffer().get_tensor( (output_size[0]*output_size[1], output_size[2], output_size[3]), query_layer.dtype, "mpu") - timers("CoreAttention: QK matmul").start() # Raw attention scores. 
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0/self.norm_factor)) - timers("CoreAttention: QK matmul").stop() # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -257,7 +252,6 @@ def forward(self, query_layer, key_layer, # Attention probs and dropout # =========================== - timers("CoreAttention: Softmax, dropout").start() # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) @@ -270,7 +264,6 @@ def forward(self, query_layer, key_layer, attention_probs = self.attention_dropout(attention_probs) else: attention_probs = self.attention_dropout(attention_probs) - timers("CoreAttention: Softmax, dropout").stop() # ========================= # Context layer. [sq, b, hp] @@ -285,28 +278,22 @@ def forward(self, query_layer, key_layer, query_layer.size(0), value_layer.size(3)) - timers("CoreAttention: V view/reshape").start() # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - timers("CoreAttention: V view/reshape").stop() # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - timers("CoreAttention: V matmul").start() # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - timers("CoreAttention: V matmul").stop() # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) - timers("CoreAttention: context contiguous").start() # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - timers("CoreAttention: context contiguous").stop() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ @@ -322,7 +309,6 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, query_layer, key_layer, value_layer, attention_mask): - timers = get_timers() # =================================== # Raw attention scores. [b, np, s, s] # =================================== @@ -340,8 +326,6 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): query_layer.size(0), key_layer.size(0)) - - timers("CoreAttention: K view/reshape").start() # [sq, b, np, hn] -> [b, np * sq, hn] query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1) # [sk, b, 1, hn] -> [b, hn, sk] @@ -354,16 +338,13 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): matmul_input_buffer = get_global_memory_buffer().get_tensor( (bs, np * sq, sk), query_layer.dtype, "mpu") - timers("CoreAttention: K view/reshape").stop() - timers("CoreAttention: QK matmul").start() # Raw attention scores. 
[b, np * sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer, # [b, np * sq, hn] key_layer, # [b, hn, sk] beta=0.0, alpha=(1.0/self.norm_factor)) - timers("CoreAttention: QK matmul").stop() # change view to [b, np, sq, sk] attention_scores = matmul_result.view(bs, np, sq, sk) @@ -372,7 +353,6 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # Attention probs and dropout # =========================== - timers("CoreAttention: Softmax, dropout").start() # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) @@ -385,7 +365,6 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): attention_probs = self.attention_dropout(attention_probs) else: attention_probs = self.attention_dropout(attention_probs) - timers("CoreAttention: Softmax, dropout").stop() # ========================= # Context layer. [sq, b, hp] @@ -400,26 +379,20 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): query_layer.size(0), value_layer.size(3)) - timers("CoreAttention: V view/reshape").start() # [sk, b, 1, hn] -> [b, sk, hn] value_layer = value_layer.squeeze(2).transpose(0, 1) - timers("CoreAttention: V view/reshape").stop() # change view [b, np * sq, sk] attention_probs = attention_probs.view(bs, np * sq, -1) - timers("CoreAttention: V matmul").start() # matmul: [b, np * sq, hn] context_layer = torch.bmm(attention_probs, value_layer) - timers("CoreAttention: V matmul").stop() # change view [b, np, sq, hn] context_layer = context_layer.view(bs, np, sq, -1) - timers("CoreAttention: context contiguous").start() # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - timers("CoreAttention: context contiguous").stop() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ @@ -541,11 +514,9 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size): def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None): # hidden_states: [sq, b, h] - timers = get_timers() # ================================================= # Pre-allocate memory for key-values for inference. 
# ================================================= - timers("inference_params init").start() if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len @@ -559,13 +530,11 @@ def forward(self, hidden_states, attention_mask, else: inference_key_memory, inference_value_memory = \ inference_params.key_value_memory_dict[self.layer_number] - timers("inference_params init").stop() # ===================== # Query, Key, and Value # ===================== - timers("KV forward").start() if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) @@ -631,15 +600,11 @@ def forward(self, hidden_states, attention_mask, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) - timers("KV forward").stop() - # ================================== # Adjust key and value for inference # ================================== - timers("Inference params").start() - if inference_params: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) @@ -656,30 +621,22 @@ def forward(self, hidden_states, attention_mask, :sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[ :sequence_end, batch_start:batch_end, ...] - - timers("Inference params").stop() # ================================== # core attention computation # ================================== - timers("Core attention forward").start() - if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( query_layer, key_layer, value_layer, attention_mask) else: context_layer = self.core_attention( query_layer, key_layer, value_layer, attention_mask) - timers("Core attention forward").stop() # ================= # Output. [sq, b, h] # ================= - - timers("dense").start() output, bias = self.dense(context_layer) - timers("dense").stop() return output, bias @@ -790,19 +747,16 @@ def __init__(self, init_method, output_layer_init_method, def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None): - timers = get_timers() # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. - timers("attention forward").start() attention_output, attention_bias = \ self.self_attention( layernorm_output, attention_mask, inference_params=inference_params) - timers("attention forward").stop() # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -860,9 +814,7 @@ def forward(self, hidden_states, attention_mask, layernorm_output = self.post_inter_attention_layernorm(layernorm_input) # MLP. - timers("MLP forward").start() mlp_output, mlp_bias = self.mlp(layernorm_output) - timers("MLP forward").stop() # Second residual connection. if self.apply_residual_connection_post_layernorm: @@ -1087,8 +1039,9 @@ def forward(self, hidden_states, attention_mask, inference_params=None): # hidden_states: [s, b, h] timers = get_timers() + args = get_args() - timers("Transformer forward").start() + if args.transformer_timers: timers("Transformer forward").start() # Checks. 
if inference_params: @@ -1146,6 +1099,6 @@ def forward(self, hidden_states, attention_mask, if self.post_process and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) - timers("Transformer forward").stop() + if args.transformer_timers: timers("Transformer forward").stop() return hidden_states diff --git a/tools/text_generation_benchmark.py b/tools/text_generation_benchmark.py new file mode 100644 index 0000000..ee458f3 --- /dev/null +++ b/tools/text_generation_benchmark.py @@ -0,0 +1,163 @@ + +"""Sample Generate GPT""" +import os +import sys +import re +sys.path.append(os.path.abspath(os.path.join( + os.getcwd(), + "Megatron-LM", +))) +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import mpu +from megatron.checkpointing import load_checkpoint +from megatron.initialize import initialize_megatron +from megatron.model import GPTModel +from megatron.training import get_model +from megatron.text_generation import generate_and_post_process +import torch +from human_eval.data import write_jsonl, read_problems +from tqdm import tqdm + + +GENERATE_NUM = 0 + +# End on unindented code +# EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"] + + +BATCH_SIZE = 512 +TOKENS_TO_GENERATE = 128 +PROMPT_LENGTH = 128 +NUM_BATCHES = 8 + + +# NUM_SAMPLES_PER_TASK = 5 +# # Number of human-eval tasks +# NUM_TASKS = 200 + +def send_do_generate(): + choice = torch.cuda.LongTensor([GENERATE_NUM]) + torch.distributed.broadcast(choice, 0) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) + + return model + +def get_batches(prompts, batch_size): + for start_idx in tqdm(range(0, len(prompts), batch_size)): + actual_batch_size = min(batch_size, len(prompts) - start_idx) + yield prompts[start_idx: start_idx + actual_batch_size] + + +def unbatch(d: dict): + return [dict(zip(d.keys(), t)) for t in zip(*d.values())] + + +# Use fixed-length prompts +def load_evaluation_data(args): + # HumanEval data + # problems = read_problems() + + # batches = get_batches( + # [ + # problems[task_id]["prompt"] + # for task_id in problems + # for _ in range(5) + # ], + # BATCH_SIZE + # ) + # return batches + + prompt = " ".join(["one"] * PROMPT_LENGTH) + prompts = [prompt] * (BATCH_SIZE * NUM_BATCHES) + + batches = get_batches(prompts, BATCH_SIZE) + return batches + + +if __name__ == "__main__": + # Initialize Megatron + initialize_megatron(extra_args_provider=None, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer', + 'no_load_rng': True, + 'no_load_optim': True}) + + args = get_args() + timers = get_timers() + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for text generation.") + exit() + # Setup model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + + if args.load is not None: + iteration = load_checkpoint(model, None, None, iteration=None) + else: + iteration = None + + assert len(model) == 1 + model = model[0] + + def generate(prompts): + response, response_seg, response_logprobs, tokens = \ + generate_and_post_process( + model, + prompts=prompts, + tokens_to_generate=TOKENS_TO_GENERATE, + return_output_log_probs=True, + use_eod_token_for_early_termination=False) + + assert all([r.startswith(p) for r, p in zip(response, prompts)]) + result 
= { + "response": response, + "response_seg": response_seg, + "raw_completion": [r[len(p):] for r, p in zip(response, prompts)] + } + # The "completion" field contains the string that is actually going to be evaluated by the HumanEval script + # result["completion"] = [post_process_completion(c) for c in result["raw_completion"]] + # Return a list of dicts + return unbatch(result) + + # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + # server = MegatronServer(model) + # server.run("0.0.0.0") + + # while True: + # choice = torch.cuda.LongTensor(1) + # torch.distributed.broadcast(choice, 0) + # if choice[0].item() == 0: + # generate_and_post_process(model) + + + # Evaluation data iterator + batches = load_evaluation_data(args) + + timers('generate').start() + # Generate + samples = [ + generate_dict + for batch in batches + for generate_dict in generate(batch) + ] + timers('generate').stop() + + elapsed = timers.timers['generate'].elapsed(reset=False) + num_tokens = TOKENS_TO_GENERATE * NUM_BATCHES * BATCH_SIZE + print(f"{elapsed * 1000 / (num_tokens)} ms per token") + timers.log(['generate']) + if args.transformer_timers: + timers.log(["Transformer forward"]) + print("DONE") + + # Write results to file + # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + # write_jsonl(args.output_file.format(iteration), samples) + From 5045d6f191480c87027132b90cd451f1a954503f Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Wed, 7 Sep 2022 10:22:15 -0400 Subject: [PATCH 7/8] resolve conflict --- megatron/model/transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 31af14f..08e5fb2 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -228,12 +228,8 @@ def __init__(self, layer_number, self.attention_dropout = torch.nn.Dropout(args.attention_dropout) def forward(self, query_layer, key_layer, -<<<<<<< HEAD - value_layer, attention_mask): -======= value_layer, attention_mask, alibi): ->>>>>>> load-iter # =================================== # Raw attention scores. [b, np, s, s] # =================================== From 21170585ab38bce077f1e28d31405d861d3e4cb5 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Wed, 7 Sep 2022 11:15:45 -0400 Subject: [PATCH 8/8] implement alibi in multiquery core-attention --- megatron/model/transformer.py | 47 ++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 08e5fb2..57d6d9f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -254,6 +254,7 @@ def forward(self, query_layer, key_layer, (output_size[0]*output_size[1], output_size[2], output_size[3]), query_layer.dtype, "mpu") else: + # alibi: (batch_size * num_attention_heads, 1, max_seq_len) matmul_input_buffer = alibi[:output_size[0]*output_size[1], :, :output_size[3]] # Raw attention scores. [b * np, sq, sk] @@ -342,7 +343,7 @@ class MultiQueryCoreAttention(CoreAttention): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def forward(self, query_layer, key_layer, value_layer, attention_mask): + def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi): # =================================== # Raw attention scores. 
[b, np, s, s] # =================================== @@ -368,17 +369,39 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) # key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) - # preallocting input tensor: [b, np * sq, sk] - matmul_input_buffer = get_global_memory_buffer().get_tensor( - (bs, np * sq, sk), - query_layer.dtype, "mpu") - - # Raw attention scores. [b, np * sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer, # [b, np * sq, hn] - key_layer, # [b, hn, sk] - beta=0.0, alpha=(1.0/self.norm_factor)) + if alibi is None: + # preallocting input tensor: [b, np * sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (bs, np * sq, sk), + query_layer.dtype, "mpu") + else: + # alibi: (batch_size * num_attention_heads, 1, max_seq_len) + # TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk) + matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk) + matmul_input_buffer = matmul_input_buffer.repeat(1, sq, 1) # [b, np * sq, sk] + + if alibi is None: + # Raw attention scores. [b, np * sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, # [b, np * sq, hn] + key_layer, # [b, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + else: + if not hasattr(self, "logged_alibi"): + print("Using Alibi.") + self.logged_alibi = True + + if self.apply_query_key_layer_scaling: + beta = 1.0 / self.layer_number + else: + beta = 1.0 + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=beta, alpha=(1.0 / self.norm_factor)) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(bs, np, sq, sk)
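
For reference, the multi-query attention computation that MultiQueryCoreAttention implements in the patches above boils down to the following standalone PyTorch sketch. It keeps the shape notation from transformer.py (queries of shape [sq, b, np, hn], a single key/value head of shape [sk, b, 1, hn] shared by all np query heads) but omits the Megatron-specific pieces: the fused scale-mask softmax, the attention-dropout rng tracker, the preallocated baddbmm buffer, ALiBi, and tensor parallelism. Function and variable names here are illustrative only, not part of the patch.

import torch


def multi_query_attention(query, key, value, mask=None):
    # query: [sq, b, np, hn] -- one projection per attention head
    # key:   [sk, b, 1, hn]  -- a single head shared across all query heads
    # value: [sk, b, 1, hn]
    sq, b, np, hn = query.shape
    sk = key.size(0)
    norm_factor = hn ** 0.5

    # [sq, b, np, hn] -> [b, np * sq, hn]
    q = query.permute(1, 2, 0, 3).reshape(b, np * sq, hn)
    # [sk, b, 1, hn] -> [b, hn, sk]
    k = key.squeeze(2).permute(1, 2, 0)
    # [sk, b, 1, hn] -> [b, sk, hn]
    v = value.squeeze(2).transpose(0, 1)

    # Raw scores in one batched matmul per sample instead of b * np smaller ones:
    # [b, np * sq, hn] @ [b, hn, sk] -> [b, np * sq, sk]
    scores = torch.bmm(q, k) / norm_factor
    scores = scores.view(b, np, sq, sk)
    if mask is not None:
        # mask: broadcastable to [b, np, sq, sk], True marks masked positions
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)

    # [b, np * sq, sk] @ [b, sk, hn] -> [b, np * sq, hn]
    context = torch.bmm(probs.view(b, np * sq, sk), v)
    # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, np * hn]
    context = context.view(b, np, sq, hn).permute(2, 0, 1, 3)
    return context.reshape(sq, b, np * hn)


# Example: 4 query heads sharing one key/value head.
q = torch.randn(5, 2, 4, 8)   # [sq, b, np, hn]
kv = torch.randn(5, 2, 1, 8)  # [sk, b, 1, hn]
out = multi_query_attention(q, kv, kv)
print(out.shape)  # torch.Size([5, 2, 32])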
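
The tensor-parallel handling of the shared key-value projection (PATCH 2/8) can be summarized the same way: key_value is a plain replicated linear layer rather than a ColumnParallelLinear, so each tensor-parallel rank only sees the gradient contribution from its own slice of query heads, and the pieces have to be summed across the tensor-parallel group after the backward pass. A minimal sketch, assuming `layers` is the unwrapped list of transformer layers and `tp_group` is the tensor-parallel process group (both placeholders for the Megatron equivalents used in allreduce_key_value_grads):

import torch.distributed as dist


def allreduce_key_value_grads(layers, tp_group):
    """Sum the shared key-value weight gradients over tensor-parallel ranks."""
    for layer in layers:
        kv_weight = layer.self_attention.key_value.weight
        # Megatron's local DDP accumulates into `.main_grad`; torch DDP uses `.grad`.
        if hasattr(kv_weight, "main_grad"):
            grad = kv_weight.main_grad
        else:
            grad = kv_weight.grad
        # Each rank holds only a partial gradient (from its own attention heads),
        # so the full gradient is the sum over the tensor-parallel group.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_group)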