jaandoui committed
Commit 5100834
1 Parent(s): 57cbc98

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +15 -9
bert_layers.py CHANGED
@@ -911,8 +911,14 @@ class BertForSequenceClassification(BertPreTrainedModel):
 
         # JAANDOUI:
         all_attention_weights = outputs[2]
-
-        # print(f'last: {all_attention_weights}')
+        try:
+            print(f'last: {all_attention_weights.shape}')
+        except:
+            try:
+                print(f'last: {all_attention_weights[0].shape}')
+            except:
+                print(f'last: {len(all_attention_weights[0])}')
+
 
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -956,12 +962,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
         print(f'not stacked final attention LEN: {len(outputs[2])}')
 
-        try:
-            print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
-        except:
-            try:
-                print(f'STACKED final attention LEN: {(outputs.attentions)[0].shape}')
-            except:
-                print(f'STACKED final attention LEN 2: {len(outputs.attentions)}')
+        # try:
+        #     print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
+        # except:
+        #     try:
+        #         print(f'STACKED final attention LEN: {(outputs.attentions)[0].shape}')
+        #     except:
+        #         print(f'STACKED final attention LEN 2: {len(outputs.attentions)}')
 
         return SequenceClassifierOutput(
             loss=loss,
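For readers following the debug prints: the cascading try/except probes whether outputs[2] is already a single tensor or still a per-layer tuple. The sketch below assumes the usual Hugging Face convention, where attention weights come back as a tuple with one tensor per encoder layer, each of shape (batch_size, num_heads, seq_len, seq_len); all sizes below are invented for illustration.

import torch

# Stand-in for outputs[2] / outputs.attentions under the assumed convention:
# one attention tensor per encoder layer (layer/head/sequence sizes made up).
num_layers, batch_size, num_heads, seq_len = 12, 2, 12, 128
all_attention_weights = tuple(
    torch.rand(batch_size, num_heads, seq_len, seq_len)
    for _ in range(num_layers)
)

# The first probe raises AttributeError (a tuple has no .shape), so the
# fallback fires: each element is a tensor with a well-defined shape.
print(all_attention_weights[0].shape)  # torch.Size([2, 12, 128, 128])

# The "STACKED" prints in the second hunk chase the same question:
# torch.stack collapses the tuple into a single tensor of shape
# (num_layers, batch_size, num_heads, seq_len, seq_len).
stacked = torch.stack(all_attention_weights)
print(stacked.shape)  # torch.Size([12, 2, 12, 128, 128])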
 
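Since forward() ends by building a SequenceClassifierOutput, the same tuple should be reachable from the caller without any in-model prints, assuming the attentions are attached to the returned object (the diff is truncated right after loss=loss,). A hedged usage sketch: the checkpoint id and input string are placeholders, and passing output_attentions=True through from_pretrained assumes this fork's config and forward() honor that flag the way stock transformers models do.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint id; trust_remote_code=True is needed for a repo
# that ships its own bert_layers.py.
ckpt = "user/model-with-custom-bert-layers"
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    ckpt, trust_remote_code=True, output_attentions=True
)

inputs = tokenizer("some input text", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is the tuple printed as all_attention_weights above;
# stacking gives one (num_layers, batch, heads, seq, seq) tensor.
attn = torch.stack(outputs.attentions)
print(attn.shape)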