Update bert_layers.py
Browse files- bert_layers.py +2 -1
bert_layers.py
CHANGED
@@ -954,10 +954,11 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
954 |
print(f'not stacked final attention SHAPE: {outputs[2].shape}')
|
955 |
except:
|
956 |
print(f'not stacked final attention LEN: {len(outputs[2])}')
|
957 |
-
|
958 |
print(f'STACKED final attention SHAPE: {(torch.stack(outputs[2], dim=0)).shape}')
|
959 |
except:
|
960 |
print(f'STACKED final attention LEN: {len(torch.stack(outputs[2], dim=0))}')
|
|
|
961 |
return SequenceClassifierOutput(
|
962 |
loss=loss,
|
963 |
logits=logits,
|
|
|
954 |
print(f'not stacked final attention SHAPE: {outputs[2].shape}')
|
955 |
except:
|
956 |
print(f'not stacked final attention LEN: {len(outputs[2])}')
|
957 |
+
try:
|
958 |
print(f'STACKED final attention SHAPE: {(torch.stack(outputs[2], dim=0)).shape}')
|
959 |
except:
|
960 |
print(f'STACKED final attention LEN: {len(torch.stack(outputs[2], dim=0))}')
|
961 |
+
|
962 |
return SequenceClassifierOutput(
|
963 |
loss=loss,
|
964 |
logits=logits,
|