jaandoui committed
Commit 969245c
1 Parent(s): 65dd5c9

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +13 -3
bert_layers.py CHANGED
@@ -169,9 +169,12 @@ class BertUnpadSelfAttention(nn.Module):
                 self.attention_head_size)
             attention_scores = attention_scores + bias
             attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+            print(f'BUSA: attention_probs 1 shape: {attention_probs.shape}')
             attention_probs = self.dropout(attention_probs)
+            print(f'BUSA: attention_probs 2 shape: {attention_probs.shape}')
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
+            print(f'BUSA: attention shape: {attention.shape}')
         else:
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
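
Note on the hunk above: the new prints trace the shapes along the non-Triton attention path. A minimal standalone sketch (made-up sizes, not the repo's packed-QKV code) of what those shapes look like: attention_probs is (batch, heads, seq, seq) and the permuted context is (batch, seq, heads, head_dim).

import torch

# Toy shapes only; b/h/s/d stand for batch, heads, sequence length, head dim.
b, h, s, d = 2, 12, 128, 64
q, v = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
k = torch.randn(b, h, d, s)
attention_scores = torch.matmul(q, k) / d ** 0.5                    # (b, h, s, s)
attention_probs = attention_scores.softmax(dim=-1)                  # (b, h, s, s)
attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)    # (b, s, h, d)
print(attention_probs.shape, attention.shape)
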
@@ -182,13 +185,16 @@ class BertUnpadSelfAttention(nn.Module):
                 bias_dtype = bias.dtype
                 bias = bias.to(torch.float16)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
+                print(f'BUSA Triton: attention 0 shape: {attention.shape}')
                 attention = attention.to(orig_dtype)
+                print(f'BUSA Triton: attention 1 shape: {attention.shape}')
                 bias = bias.to(bias_dtype)
             else:
                 attention = flash_attn_qkvpacked_func(qkv, bias)
-
+        print(f'BUSA Triton: attention 2 shape: {attention.shape}')
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
+        print(f'BUSA unpadded final attention shape: {attention.shape}')
         return rearrange(attention, 'nnz h d -> nnz (h d)')
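
Note on the hunk above: unpad_input_only keeps only the token positions whose attn_mask entry is 1 and flattens the batch into nnz rows, which is why the final rearrange works on 'nnz h d'. A rough standalone illustration of that idea (not the repo's implementation):

import torch

def unpad_by_mask(x, attn_mask):
    # keep positions where the mask is 1, flattening (batch, seq, ...) into (nnz, ...)
    keep = attn_mask.bool().reshape(-1)
    return x.reshape(-1, *x.shape[2:])[keep]

attn_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
hidden = torch.randn(2, 3, 4)
print(unpad_by_mask(hidden, attn_mask).shape)   # torch.Size([3, 4]) -> nnz = 3
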
 
@@ -329,7 +335,9 @@ class BertLayer(nn.Module):
         """
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
                                           subset_idx, indices, attn_mask, bias)
+        print(f'BertLayer attention_output shape: {attention_output.shape}')
         layer_output = self.mlp(attention_output)
+        print(f'BertLayer layer_output shape: {layer_output.shape}')
         return layer_output, attention_output  # JAANDOUI: this only returns layer_output in the original work.
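
Note on the hunk above: since BertLayer now returns a (layer_output, attention_output) pair instead of a single tensor, every caller has to unpack it, and a loop over layers has to collect the second element separately. A toy sketch of that calling pattern (stand-in modules, not the repo's encoder):

import torch
from torch import nn

class ToyLayer(nn.Module):
    # stand-in for a layer that returns (layer_output, attention_output)
    def forward(self, hidden_states):
        attention_output = hidden_states            # placeholder for self.attention(...)
        layer_output = attention_output * 2.0       # placeholder for self.mlp(...)
        return layer_output, attention_output

layers = nn.ModuleList([ToyLayer() for _ in range(3)])
hidden_states = torch.randn(2, 4, 8)
collected = []
for layer in layers:
    hidden_states, attn_out = layer(hidden_states)  # unpack the new 2-tuple
    collected.append(attn_out)
print(len(collected), hidden_states.shape)
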
 
@@ -350,7 +358,7 @@ class BertEncoder(nn.Module):
             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

         self.num_attention_heads = config.num_attention_heads
-
+        print(f'nbr of attention heads: {self.num_attention_heads}')
         # The alibi mask will be dynamically expanded if it is too small for
         # the input the model receives. But it generally helps to initialize it
         # to a reasonably large size to help pre-allocate CUDA memory.
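
Note on the hunk above: the comment refers to the ALiBi bias used instead of position embeddings: one negative slope per head scaled by the token distance, pre-built once at a generous length and sliced or re-expanded for the inputs actually seen. A rough sketch of that idea (not the repo's exact routine; the slope formula assumes a power-of-two head count):

import torch

def alibi_bias(num_heads, max_len):
    # one negative slope per head, scaled by the relative distance |i - j|
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(max_len)
    distance = (pos[None, :] - pos[:, None]).abs()
    return -slopes[:, None, None] * distance[None, :, :]     # (heads, max_len, max_len)

bias = alibi_bias(num_heads=12, max_len=512)
print(bias.shape, bias[:, :128, :128].shape)                  # pre-allocate large, slice when shorter
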
@@ -937,6 +945,7 @@ class BertForSequenceClassification(BertPreTrainedModel):

         if not return_dict:
             # JAANDOUI TODO maybe.
+            print(f'return_dict is {return_dict}')
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
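
Note on the hunk above: with the encoder in this fork also handing back the per-layer attention tensors (see the next hunk's comment), outputs[2:] in this tuple path would carry them too, so the non-return_dict result looks like (loss, logits, attentions). A toy check of the tuple-building pattern (stand-in tensors only):

import torch

all_attention_weights = tuple(torch.randn(2, 12, 128, 128) for _ in range(6))
outputs = (torch.randn(2, 128, 768), torch.randn(2, 768), all_attention_weights)  # stand-in model outputs
logits, loss = torch.randn(2, 3), torch.tensor(0.41)

output = (logits,) + outputs[2:]
result = ((loss,) + output) if loss is not None else output
print(len(result), len(result[2]))   # 3 6 -> (loss, logits, per-layer attentions)
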
 
@@ -947,6 +956,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             logits=logits,
             hidden_states=outputs[0],
             # JAANDOUI: returning all_attention_weights here
-            attentions=torch.stack(outputs[2], dim=0),
+            # attentions=torch.stack(outputs[2], dim=0),
+            attentions=torch.stack(outputs[2], dim=0),  # JAANDOUI TODO: should I stack here ????
         )
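
Note on the stacking TODO above: Hugging Face model outputs conventionally keep attentions as a tuple with one (batch, heads, seq, seq) tensor per layer, whereas torch.stack(..., dim=0) folds them into a single (layers, batch, heads, seq, seq) tensor. Both are usable downstream, they just index differently, and stacking additionally requires every layer's tensor to have the same shape. A small sketch of the two access patterns (toy tensors):

import torch

per_layer = tuple(torch.randn(2, 12, 128, 128) for _ in range(6))   # tuple convention: one tensor per layer
stacked = torch.stack(per_layer, dim=0)                             # stacked: (layers, batch, heads, seq, seq)

print(per_layer[3].shape)    # layer 3 via tuple indexing
print(stacked[3].shape)      # the same tensor via the extra leading dimension
print(stacked.shape)
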
 
 