jaandoui committed
Commit 5150f64
1 Parent(s): 05f5b36

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +38 -18
bert_layers.py CHANGED
@@ -1,3 +1,5 @@
+# search for JAANDOUI for the parts I have modified, and for JAANDOUI TODO for the parts that might need to be changed.
+
 # Copyright 2022 MosaicML Examples authors
 # SPDX-License-Identifier: Apache-2.0
 
@@ -328,7 +330,7 @@ class BertLayer(nn.Module):
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
                                           subset_idx, indices, attn_mask, bias)
         layer_output = self.mlp(attention_output)
-        return layer_output, attention_output
+        return layer_output, attention_output  # JAANDOUI: this only returns layer_output in the original work.
 
 
 class BertEncoder(nn.Module):
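As a quick illustration of the contract this hunk documents, a caller of `BertLayer.forward` now unpacks two values. This is only a sketch, not part of the commit: `layer` stands for a `BertLayer` instance, the surrounding tensors are assumed to be prepared as in `BertEncoder` below, and the positional `None` stands for `subset_idx` as in those call sites.

```python
# Sketch only (not part of the commit): BertLayer.forward now returns
# (layer_output, attention_output) instead of just layer_output.
layer_output, attention_output = layer(hidden_states,
                                       cu_seqlens,
                                       seqlen,
                                       None,        # subset_idx
                                       indices,
                                       attn_mask=attention_mask,
                                       bias=alibi_attn_mask)
```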
@@ -343,7 +345,7 @@ class BertEncoder(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config)  # JAANDOUI: this defines the BertLayer; note that its forward now returns the attention too (2 values instead of 1).
         self.layer = nn.ModuleList(
             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
@@ -446,6 +448,7 @@ class BertEncoder(nn.Module):
 
         if subset_mask is None:
             for layer_module in self.layer:
+                # JAANDOUI: since we now get the attention too, we need to unpack 2 elements instead of 1.
                 hidden_states, attention_weights = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
@@ -453,8 +456,9 @@ class BertEncoder(nn.Module):
                                                                 indices,
                                                                 attn_mask=attention_mask,
                                                                 bias=alibi_attn_mask)
-                print(f'Inner Attention: {attention_weights}')
-                print(f'Inner Attention shape: {attention_weights.shape}')
+                # JAANDOUI
+                # print(f'Inner Attention: {attention_weights}')
+                # print(f'Inner Attention shape: {attention_weights.shape}')
                 all_attention_weights.append(attention_weights)  # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
@@ -467,6 +471,7 @@ class BertEncoder(nn.Module):
         else:
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
+                # JAANDOUI: since we now get the attention too, we need to unpack 2 elements instead of 1.
                 hidden_states, attention_weights = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
@@ -474,11 +479,12 @@ class BertEncoder(nn.Module):
                                                                 indices,
                                                                 attn_mask=attention_mask,
                                                                 bias=alibi_attn_mask)
-                all_attention_weights.append(attention_weights)  # Store attention weights
+                all_attention_weights.append(attention_weights)  # JAANDOUI: Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                        as_tuple=False).flatten()
+            # JAANDOUI: since we now get the attention too, we need to unpack 2 elements instead of 1.
             hidden_states, attention_weights = self.layer[-1](hidden_states,
                                                               cu_seqlens,
                                                               seqlen,
@@ -486,14 +492,16 @@ class BertEncoder(nn.Module):
                                                               indices=indices,
                                                               attn_mask=attention_mask,
                                                               bias=alibi_attn_mask)
-            all_attention_weights.append(attention_weights)  # Store attention weights
-            print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
-            print(f'and this is the shape inside encoder: \n {all_attention_weights.shape}')
+            all_attention_weights.append(attention_weights)  # JAANDOUI: appending the attention of the different layers together.
+            # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
+            # print(f'and this is the shape inside encoder: \n {all_attention_weights.shape}')
 
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        # return all_encoder_layers, all_attention_weights # Return both hidden states and attention weights
-        return all_encoder_layers # Return both hidden states and attention weights
+
+        # JAANDOUI: since we now return both, they need to be handled wherever BertEncoder's forward is called.
+        return all_encoder_layers, all_attention_weights  # Return both hidden states and attention weights
+        # return all_encoder_layers  # JAANDOUI: original return.
 
 
 
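To make the new return shape concrete, here is a minimal sketch of a caller consuming `BertEncoder.forward` after this change. It assumes `encoder` is a `BertEncoder` instance and that `hidden_states` and `attention_mask` are already prepared as in `BertModel`; whether the collected tensors can be stacked depends on the attention implementation, so that step is only valid under the assumption stated in the comment.

```python
import torch

# Sketch only: BertEncoder.forward now returns a pair.
all_encoder_layers, all_attention_weights = encoder(
    hidden_states,
    attention_mask,
    output_all_encoded_layers=False)

# all_attention_weights is a plain Python list with one entry per BertLayer.
# Stacking assumes every layer produced a tensor of the same shape.
per_layer_attention = torch.stack(all_attention_weights, dim=0)
```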
 
@@ -617,7 +625,9 @@ class BertModel(BertPreTrainedModel):
             first_col_mask[:, 0] = True
             subset_mask = masked_tokens_mask | first_col_mask
 
-        encoder_outputs = self.encoder(
+        # JAANDOUI: first place where we call self.encoder (the BertEncoder instance defined above);
+        # JAANDOUI: the attention weights need to be returned here too.
+        encoder_outputs, all_attention_weights = self.encoder(
             embedding_output,
             attention_mask,
             output_all_encoded_layers=output_all_encoded_layers,
@@ -645,11 +655,13 @@ class BertModel(BertPreTrainedModel):
         if not output_all_encoded_layers:
             encoder_outputs = sequence_output
 
+        # JAANDOUI: returning all_attention_weights too
         if self.pooler is not None:
-            return encoder_outputs, pooled_output
-
-        return encoder_outputs, None
+            return encoder_outputs, pooled_output, all_attention_weights
 
+        # JAANDOUI: returning all_attention_weights too
+        return encoder_outputs, None, all_attention_weights
+        # JAANDOUI: the returned elements need to be handled wherever BertModel is instantiated.
 
 ###################
 # Bert Heads
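Since `BertModel.forward` now hands back three values, any code that instantiates it has to unpack all of them, as the comment above warns. A minimal sketch with illustrative names only (`bert_model` and the input tensors are assumptions, not defined by this commit):

```python
# Sketch only: the modified BertModel.forward returns three values.
encoder_outputs, pooled_output, all_attention_weights = bert_model(
    input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids)
```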
@@ -705,6 +717,8 @@ class BertForMaskedLM(BertPreTrainedModel):
                 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                 'bi-directional self-attention.')
 
+        # JAANDOUI: this part is only for pretraining; I don't think it is called when finetuning.
+        # JAANDOUI: if pretraining, handle the returned elements of BertModel there (we now get 3 elements).
         self.bert = BertModel(config, add_pooling_layer=False)
         self.cls = BertOnlyMLMHead(config,
                                    self.bert.embeddings.word_embeddings.weight)
@@ -754,6 +768,7 @@ class BertForMaskedLM(BertPreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # JAANDOUI: for pretraining, the return is handled here.
         outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
@@ -789,7 +804,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                                          b=batch)
 
         if not return_dict:
-            output = (prediction_scores,) + outputs[2:]
+            output = (prediction_scores,) + outputs[2:]  # JAANDOUI TODO: might need to handle this part, and everywhere else that consumes outputs (outputs now has 3 elements, not 2).
             return ((loss,) + output) if loss is not None else output
 
         return MaskedLMOutput(
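One way the `JAANDOUI TODO` above could be resolved, sketched as a fragment of `BertForMaskedLM.forward` under the assumption that `outputs` is the three-element tuple now returned by the modified `BertModel`. This is not what the commit does; it is only an illustration.

```python
# Sketch only (fragment of BertForMaskedLM.forward): unpack the three-element
# return explicitly rather than slicing, then choose what the legacy
# (non return_dict) tuple should carry.
sequence_output, pooled_output, all_attention_weights = outputs

if not return_dict:
    output = (prediction_scores, all_attention_weights)
    return ((loss,) + output) if loss is not None else output
```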
@@ -823,7 +838,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         return {'input_ids': input_ids, 'attention_mask': attention_mask}
 
 
-
+# JAANDOUI: this model is the one used for finetuning.
 class BertForSequenceClassification(BertPreTrainedModel):
     """Bert Model transformer with a sequence classification/regression head.
 
@@ -869,7 +884,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.bert(
+        outputs, _, all_attention_weights = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -882,6 +897,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
         )
 
         pooled_output = outputs[1]
+
+        # JAANDOUI:
+        all_attention_weights = outputs[2]
 
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -913,6 +931,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
                 loss = loss_fct(logits, labels)
 
         if not return_dict:
+            # JAANDOUI TODO maybe.
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
@@ -923,6 +942,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             loss=loss,
             logits=logits,
             hidden_states=outputs[0],
-            attentions=None,
+            # JAANDOUI: returning all_attention_weights here
+            attentions=outputs[2],
         )
 
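Finally, a hedged sketch of how a fine-tuning script might read the collected attention back out of the `SequenceClassifierOutput` built above. `model`, `input_ids` and `attention_mask` are placeholders for an already fine-tuned `BertForSequenceClassification` and a tokenized batch; they are not defined by this commit.

```python
import torch

# Sketch only: the classification output now carries the per-layer attention
# list (outputs[2] above) in its `attentions` field.
model.eval()
with torch.no_grad():
    out = model(input_ids, attention_mask=attention_mask, return_dict=True)

logits = out.logits
per_layer_attention = out.attentions
```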
 
 