Update bert_layers.py
bert_layers.py  +2 -0
@@ -196,6 +196,8 @@ class BertUnpadSelfAttention(nn.Module):
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
         print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
         print(f'ATTENTION: {attention.shape}')
+
+        print(f'PROBLEM HERE: UNDERSTAND IT!!')
         rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
         try:
             print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
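For context on the shapes this hunk is debugging: the rearrange call merges the per-head outputs of each unpadded token into a single hidden vector. Below is a minimal sketch, not the repository's code, assuming `attention` enters as `(batch, seq_len, n_heads, head_dim)` and that `unpad_input_only` behaves like boolean indexing over the attention mask (both are assumptions used only for illustration).

# Sketch of the unpad + rearrange shapes; stand-ins are labeled, not the real helpers.
import torch
from einops import rearrange

batch, seq_len, n_heads, head_dim = 2, 8, 4, 16
attention = torch.randn(batch, seq_len, n_heads, head_dim)
attn_mask = torch.tensor([[1] * 5 + [0] * 3,   # first sequence: 3 padding positions
                          [1] * 8])            # second sequence: no padding

# Assumed stand-in for unpad_input_only: keep only positions where the mask is 1,
# collapsing (batch, seq_len) into a single "nnz" (non-padded tokens) axis.
keep = attn_mask.bool()                        # (batch, seq_len)
unpadded = attention[keep]                     # (nnz, n_heads, head_dim) -> (13, 4, 16)

# The rearrange from the diff: fold heads and head_dim into one hidden axis per token.
merged = rearrange(unpadded, 'nnz h d -> nnz (h d)')
print(unpadded.shape, merged.shape)            # torch.Size([13, 4, 16]) torch.Size([13, 64])

Under these assumptions, the `REARRANGED ATTENTION` print in the hunk should report `(nnz, n_heads * head_dim)`; a mismatch there usually means `attention` was not in `(nnz, h, d)` layout when `rearrange` ran.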