Update bert_layers.py
bert_layers.py (+38 -18)
@@ -1,3 +1,5 @@
+# search for JAANDOUI for the parts I have modified, and for JAANDOUI TODO for the parts that might need to be changed.
+
 # Copyright 2022 MosaicML Examples authors
 # SPDX-License-Identifier: Apache-2.0
 
@@ -328,7 +330,7 @@ class BertLayer(nn.Module):
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
                                           subset_idx, indices, attn_mask, bias)
         layer_output = self.mlp(attention_output)
-        return layer_output, attention_output
+        return layer_output, attention_output # JAANDOUI: this only returns layer_output in the original work.


 class BertEncoder(nn.Module):
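Because `BertLayer.forward` now returns a `(layer_output, attention_output)` tuple rather than a single tensor, any call site that still binds one name gets a tuple where it expects a tensor and fails one layer later. A self-contained toy illustrating the changed contract (the `TinyLayer` module below is a stand-in invented for this note, not the real `BertLayer`):

import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    """Stand-in for the modified BertLayer: forward returns two values."""
    def __init__(self, d: int = 8):
        super().__init__()
        self.proj = nn.Linear(d, d)

    def forward(self, x: torch.Tensor):
        attn = torch.softmax(x @ x.transpose(-2, -1), dim=-1)  # toy attention map
        return self.proj(attn @ x), attn  # (layer_output, attention_output)

x = torch.randn(2, 4, 8)
result = TinyLayer()(x)
assert isinstance(result, tuple)   # binding one name yields a tuple,
hidden, attn = result              # so callers must unpack two values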
@@ -343,7 +345,7 @@ class BertEncoder(nn.Module):

     def __init__(self, config):
         super().__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config) # JAANDOUI: In this line we define the BertLayer, note that now the forward of this class returns attention too!! 2 values instead of 1
         self.layer = nn.ModuleList(
             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

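The `copy.deepcopy` idiom above stamps out `config.num_hidden_layers` independent copies of one template layer: each copy owns its own parameter tensors, so nothing is weight-tied across layers. A quick self-contained check of that property (using `nn.Linear` as a toy template, not the real `BertLayer`):

import copy
import torch.nn as nn

template = nn.Linear(8, 8)  # stands in for layer = BertLayer(config)
layers = nn.ModuleList([copy.deepcopy(template) for _ in range(3)])

# Each deepcopy owns distinct parameter storage, so training one layer
# never updates another (or the template).
assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()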
@@ -446,6 +448,7 @@

         if subset_mask is None:
             for layer_module in self.layer:
+                # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
                 hidden_states, attention_weights = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
@@ -453,8 +456,9 @@
                                                                 indices,
                                                                 attn_mask=attention_mask,
                                                                 bias=alibi_attn_mask)
-                print(f'Inner Attention: {attention_weights}')
-                print(f'Inner Attention shape: {attention_weights.shape}')
+                # JAANDOUI
+                # print(f'Inner Attention: {attention_weights}')
+                # print(f'Inner Attention shape: {attention_weights.shape}')
                 all_attention_weights.append(attention_weights) # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
@@ -467,6 +471,7 @@
         else:
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
+                # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
                 hidden_states, attention_weights = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
@@ -474,11 +479,12 @@
                                                                 indices,
                                                                 attn_mask=attention_mask,
                                                                 bias=alibi_attn_mask)
-                all_attention_weights.append(attention_weights) # Store attention weights
+                all_attention_weights.append(attention_weights) # JAANDOUI: Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                        as_tuple=False).flatten()
+            # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
             hidden_states, attention_weights = self.layer[-1](hidden_states,
                                                               cu_seqlens,
                                                               seqlen,
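The `subset_idx` computation above packs the tokens kept by `attention_mask_bool` into a flat view and then selects which packed positions the final layer should compute. A small worked example of the `torch.nonzero(...).flatten()` idiom with toy masks (the values are illustrative, not from the model):

import torch

attention_mask_bool = torch.tensor([[True, True, True, False],
                                    [True, True, False, False]])
subset_mask = torch.tensor([[False, True, False, False],
                            [True, False, False, False]])

# Boolean indexing packs the surviving tokens into one flat dimension;
# nonzero then yields the packed positions to keep for the last layer.
packed = subset_mask[attention_mask_bool]   # tensor([False, True, False, True, False])
subset_idx = torch.nonzero(packed, as_tuple=False).flatten()
print(subset_idx)                           # tensor([1, 3])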
@@ -486,14 +492,16 @@
                                                               indices=indices,
                                                               attn_mask=attention_mask,
                                                               bias=alibi_attn_mask)
-            all_attention_weights.append(attention_weights) #
-            print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
-            print(f'and this is the shape inside encoder: \n {all_attention_weights.shape}')
+            all_attention_weights.append(attention_weights) # JAANDOUI: appending the attention of different layers together.
+            # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
+            # print(f'and this is the shape inside encoder: \n {all_attention_weights.shape}')

         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-
-        return all_encoder_layers, all_attention_weights
+
+        # JAANDOUI: Since we now return both, we need to handle them wherever BertEncoder forward is called.
+        return all_encoder_layers, all_attention_weights # Return both hidden states and attention weights
+        # return all_encoder_layers # JAANDOUI: original return.


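With `BertEncoder.forward` returning one attention tensor per layer, a natural consumer step is to stack the list into a single `(num_layers, ...)` tensor for analysis. A hedged sketch: it assumes every `attention_weights` entry is a dense tensor of identical shape, which depends on what this fork's attention module actually materializes (a FlashAttention-style kernel, for instance, does not normally expose the full probability matrix):

import torch

def stack_attentions(all_attention_weights):
    """Stack per-layer attention tensors into one (num_layers, ...) tensor.

    Raises if the per-layer shapes disagree (e.g. when the last layer ran
    on a subset of tokens), instead of letting a ragged stack fail cryptically.
    """
    shapes = {tuple(a.shape) for a in all_attention_weights}
    if len(shapes) != 1:
        raise ValueError(f'per-layer attention shapes differ: {shapes}')
    return torch.stack(all_attention_weights, dim=0)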
@@ -617,7 +625,9 @@ class BertModel(BertPreTrainedModel):
             first_col_mask[:, 0] = True
             subset_mask = masked_tokens_mask | first_col_mask

-        encoder_outputs, all_attention_weights = self.encoder(
+        # JAANDOUI: first part where we call self.encoder (which is the instance of BertEncoder defined here)
+        # JAANDOUI: need to return the attention weights here too.
+        encoder_outputs, all_attention_weights = self.encoder(
             embedding_output,
             attention_mask,
             output_all_encoded_layers=output_all_encoded_layers,
@@ -645,11 +655,13 @@ class BertModel(BertPreTrainedModel):
         if not output_all_encoded_layers:
             encoder_outputs = sequence_output

+        # JAANDOUI: returning all_attention_weights too
         if self.pooler is not None:
-            return encoder_outputs, pooled_output
+            return encoder_outputs, pooled_output, all_attention_weights

-        return encoder_outputs, None
-
+        # JAANDOUI: returning all_attention_weights too
+        return encoder_outputs, None, all_attention_weights
+        # JAANDOUI: need to handle the returned elements wherever BertModel is instantiated.

 ###################
 # Bert Heads
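Every caller of this `BertModel` now receives a 3-tuple, with `None` in the middle slot when the pooler is absent. A self-contained stub with the same return contract, to show the unpacking a caller needs (the stub and its shapes are invented for illustration):

from typing import List, Optional, Tuple
import torch

def fake_bert_model(input_ids: torch.Tensor, pooled: bool = True
                    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[torch.Tensor]]:
    """Toy stand-in mirroring (encoder_outputs, pooled_output, all_attention_weights)."""
    hidden = torch.randn(*input_ids.shape, 8)
    attns = [torch.rand(input_ids.shape[0], 2, 4, 4) for _ in range(3)]
    return hidden, (hidden[:, 0] if pooled else None), attns

sequence_output, pooled_output, all_attention_weights = fake_bert_model(
    torch.zeros(2, 4, dtype=torch.long))
if pooled_output is not None:       # None when add_pooling_layer=False
    features = pooled_output
print(len(all_attention_weights))   # one entry per encoder layer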
@@ -705,6 +717,8 @@ class BertForMaskedLM(BertPreTrainedModel):
                 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                 'bi-directional self-attention.')

+        # JAANDOUI: this part is only for the pretraining, I don't think it is called if we finetune
+        # there handle the returned elements (we now get 3 elements) of BertModel if pretraining
         self.bert = BertModel(config, add_pooling_layer=False)
         self.cls = BertOnlyMLMHead(config,
                                    self.bert.embeddings.word_embeddings.weight)
@@ -754,6 +768,7 @@

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # JAANDOUI: for the pretraining: return handled here.
         outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
@@ -789,7 +804,7 @@
                                              b=batch)

         if not return_dict:
-            output = (prediction_scores,) + outputs[2:]
+            output = (prediction_scores,) + outputs[2:] # JAANDOUI TODO: might need to handle this part and everywhere where we get outputs (now outputs has 3 elements not 2)
             return ((loss,) + output) if loss is not None else output

         return MaskedLMOutput(
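The TODO above is pure index bookkeeping: once `self.bert` returns three elements, a slice like `outputs[2:]` stops being empty and starts carrying the attention list. A self-contained illustration with plain tuples (no model code):

# Before the change: outputs == (sequence_output, pooled_output)
old_outputs = ('seq', 'pooled')
print(old_outputs[2:])    # () -- the slice used to be empty

# After the change: a third element holds the per-layer attention list.
new_outputs = ('seq', 'pooled', ['attn_layer_0', 'attn_layer_1'])
print(new_outputs[2:])    # (['attn_layer_0', 'attn_layer_1'],)

# Code that assumed outputs[2:] contributed nothing now silently forwards
# the attention list, so every tuple consumer needs auditing.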
@@ -823,7 +838,7 @@
         return {'input_ids': input_ids, 'attention_mask': attention_mask}


-
+# JAANDOUI: this model is the one used for finetuning.
 class BertForSequenceClassification(BertPreTrainedModel):
     """Bert Model transformer with a sequence classification/regression head.

@@ -869,7 +884,7 @@

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = self.bert(
+        outputs, _, all_attention_weights = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -882,6 +897,9 @@
         )

         pooled_output = outputs[1]
+
+        # JAANDOUI:
+        all_attention_weights = outputs[2]

         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
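Note a tension this hunk leaves open: the forward unpacked `outputs, _, all_attention_weights = self.bert(...)` a few lines earlier, yet still reads `outputs[1]` and `outputs[2]`. If `self.bert` returns the `(encoder_outputs, pooled_output, all_attention_weights)` triple shown above, `outputs` here is only the first element, so these positional reads index into that tensor instead. A self-contained demonstration of the hazard (toy tensors standing in for the triple):

import torch

# Toy triple mirroring (encoder_outputs, pooled_output, all_attention_weights).
bert_outputs = (torch.randn(2, 4, 8), torch.randn(2, 8), [torch.rand(2, 2, 4, 4)])

# Mixing styles is the trap: after unpacking, `outputs` is only the first
# tensor, so outputs[1] selects a row of it, not the pooled output.
outputs, _, attns = bert_outputs
row = outputs[1]            # shape (4, 8) -- NOT the (2, 8) pooled output

# Consistent alternative: bind each element to its own name exactly once.
sequence_output, pooled_output, all_attention_weights = bert_outputs
assert pooled_output.shape == (2, 8)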
@@ -913,6 +931,7 @@
                 loss = loss_fct(logits, labels)

         if not return_dict:
+            # JAANDOUI TODO maybe.
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -923,6 +942,7 @@
             loss=loss,
             logits=logits,
             hidden_states=outputs[0],
-
+            #JAANDOUI: returning all_attention_weights here
+            attentions=outputs[2],
         )

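After this hunk, the `return_dict` path exposes the attention list on the output object's `attentions` field. Assuming the file keeps returning Hugging Face's `SequenceClassifierOutput` here, as the upstream MosaicML version does, downstream code can read it by name; a minimal sketch with placeholder tensors:

import torch
from transformers.modeling_outputs import SequenceClassifierOutput

# Placeholder tensors standing in for what the modified forward returns.
out = SequenceClassifierOutput(
    loss=None,
    logits=torch.randn(2, 3),
    hidden_states=torch.randn(2, 4, 8),   # outputs[0] in the hunk above
    attentions=[torch.rand(2, 2, 4, 4)],  # outputs[2] in the hunk above
)

# Downstream code reads the per-layer attention list by name.
# (Hugging Face convention expects tuples of tensors in these two fields;
# the diff passes raw values, which the dataclass does not validate.)
print(out.logits.shape, len(out.attentions))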