jaandoui commited on
Commit
6fb83d0
1 Parent(s): 1ff80de

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +12 -8
bert_layers.py CHANGED
@@ -950,14 +950,18 @@ class BertForSequenceClassification(BertPreTrainedModel):
950
  return ((loss,) + output) if loss is not None else output
951
 
952
  # print(outputs.attentions)
953
- # try:
954
- # print(f'not stacked final attention SHAPE: {outputs[2][0].shape}')
955
- # except:
956
- # print(f'not stacked final attention LEN: {len(outputs[2])}')
957
- # try:
958
- # print(f'STACKED final attention SHAPE: {(torch.stack(outputs[2], dim=0)).shape}')
959
- # except:
960
- # print(f'STACKED final attention LEN: {len(torch.stack(outputs[2], dim=0))}')
 
 
 
 
961
 
962
  return SequenceClassifierOutput(
963
  loss=loss,
 
950
  return ((loss,) + output) if loss is not None else output
951
 
952
  # print(outputs.attentions)
953
+ try:
954
+ print(f'not stacked final attention SHAPE: {outputs[2][0].shape}')
955
+ except:
956
+ print(f'not stacked final attention LEN: {len(outputs[2])}')
957
+
958
+ try:
959
+ print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
960
+ except:
961
+ try:
962
+ print(f'STACKED final attention LEN: {(outputs.attentions)[0].shape}')
963
+ except:
964
+ print(f'STACKED final attention LEN 2: {len(outputs.attentions)}')
965
 
966
  return SequenceClassifierOutput(
967
  loss=loss,