jaandoui commited on
Commit
b3284e9
1 Parent(s): 7f3b692

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +5 -0
bert_layers.py CHANGED
@@ -195,6 +195,11 @@ class BertUnpadSelfAttention(nn.Module):
195
  # attn_mask is 1 for attend and 0 for don't
196
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
197
  print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
 
 
 
 
 
198
  return rearrange(attention, 'nnz h d -> nnz (h d)')
199
 
200
 
 
195
  # attn_mask is 1 for attend and 0 for don't
196
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
197
  print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
198
+ rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
199
+ try:
200
+ print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
201
+ except:
202
+ print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
203
  return rearrange(attention, 'nnz h d -> nnz (h d)')
204
 
205