Update bert_layers.py
Browse files- bert_layers.py +6 -0
bert_layers.py
CHANGED
@@ -248,6 +248,12 @@ class BertUnpadAttention(nn.Module):
|
|
248 |
"""
|
249 |
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
|
250 |
attn_mask, bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
if subset_idx is not None:
|
252 |
return self.output(index_first_axis(self_output, subset_idx),
|
253 |
index_first_axis(input_tensor, subset_idx))
|
|
|
248 |
"""
|
249 |
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
|
250 |
attn_mask, bias)
|
251 |
+
|
252 |
+
try:
|
253 |
+
print(f'IMPORTANT: {self_output.shape}')
|
254 |
+
except:
|
255 |
+
print(f'IMPORTANT2: {self_output[0].shape}')
|
256 |
+
|
257 |
if subset_idx is not None:
|
258 |
return self.output(index_first_axis(self_output, subset_idx),
|
259 |
index_first_axis(input_tensor, subset_idx))
|