Update bert_layers.py
bert_layers.py (+5 -0)
@@ -195,6 +195,11 @@ class BertUnpadSelfAttention(nn.Module):
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
         print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
+        rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
+        try:
+            print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
+        except:
+            print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
         return rearrange(attention, 'nnz h d -> nnz (h d)')


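Note: the added lines are debug logging only; the `rearrange(attention, 'nnz h d -> nnz (h d)')` pattern merges the head and per-head dimensions of the unpadded attention output into a single hidden dimension. A minimal sketch of that pattern with made-up toy sizes (the shapes below are illustrative, not taken from this commit):

# Minimal sketch of the einops pattern used above, with illustrative toy sizes.
import torch
from einops import rearrange

nnz, num_heads, head_dim = 10, 12, 64   # assumed example values only
attention = torch.randn(nnz, num_heads, head_dim)

# 'nnz h d -> nnz (h d)': merge the head and per-head dims into one hidden dim.
rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
print(rearranged_attention.shape)  # torch.Size([10, 768])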