ccdv committed
Commit 43c3663
Parent: 5f0fe86

update for transformers >= 4.29.1

Files changed (1)
1. modeling_lsg_bert.py  +15 -22
modeling_lsg_bert.py CHANGED
@@ -189,19 +189,25 @@ class CausalAttentionProduct(nn.Module):
         del key_layer
 
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
             # Add causal mask
             causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
             causal_mask = torch.tril(
                 torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
                 diagonal=-1
                 )
-            causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
+
+            # Min value
+            dtype_min = torch.tensor(
+                torch.finfo(attention_scores.dtype).min, device=attention_scores.device, dtype=attention_scores.dtype
+            )
 
+            # Build causal + attention_mask
+            causal_mask = torch.nn.functional.pad(causal_mask.T * dtype_min, (attention_mask.size()[-1] - self.block_size, 0), value=0)
+            attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)
+
+            attention_scores = attention_scores + attention_mask
             del attention_mask
+            del causal_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(dim=-1)(attention_scores)
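Note on this hunk: the new path folds the causal block into the additive attention mask instead of writing it into the scores in place. A minimal, self-contained sketch of that scheme follows; the shapes, block_size, and key length are illustrative assumptions, since the real code runs inside LSG's chunked attention where the mask carries extra block dimensions (hence the triple unsqueeze in the diff).

import torch

block_size, key_len = 4, 10
dtype = torch.float32
scores = torch.zeros(1, 1, block_size, key_len)       # (batch, heads, queries, keys), illustrative
attention_mask = torch.zeros(1, 1, 1, key_len)        # additive mask: 0 = keep, dtype min = masked
attention_mask[..., -2:] = torch.finfo(dtype).min     # pretend the last two keys are padding

dtype_min = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
# Strictly lower-triangular block, transposed so entry (i, j) masks keys j > i
causal = torch.tril(torch.ones(block_size, block_size, dtype=dtype), diagonal=-1)
# Left-pad the block to the full key length so it aligns with the last keys
causal = torch.nn.functional.pad(causal.T * dtype_min, (key_len - block_size, 0), value=0)
# Combine both penalties, clamped at the dtype minimum so the sum cannot overflow to -inf
combined = torch.max(attention_mask + causal.unsqueeze(0).unsqueeze(0), dtype_min)

probs = torch.softmax(scores + combined, dim=-1)
print(probs[0, 0])  # each query row attends only to allowed, non-padded keys

The clamp is the point of dtype_min: where the padding mask and the causal mask overlap, the two penalties add up and would otherwise overflow past the representable minimum.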
@@ -991,8 +997,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
     documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
    def __init__(self, config, add_pooling_layer=True):
 
        LSGBertPreTrainedModel.__init__(self, config)
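The per-class `config_class = LSGBertConfig` lines (here and in the hunks below) can go because Python class attributes are inherited; presumably LSGBertPreTrainedModel declares the attribute once for every head. A tiny stand-in sketch of that pattern, with stub classes rather than the real ones:

class LSGBertConfig:            # stand-in for the real config class
    pass

class LSGBertPreTrainedModel:   # stand-in base declaring config_class once
    config_class = LSGBertConfig

class LSGBertModel(LSGBertPreTrainedModel):
    pass                        # no per-class redeclaration needed

assert LSGBertModel.config_class is LSGBertConfig  # inherited from the base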
@@ -1031,6 +1035,8 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
 class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):
 
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1044,8 +1050,7 @@ class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):
 
 class LSGBertLMHeadModel(LSGBertPreTrainedModel, BertLMHeadModel):
 
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
 
@@ -1067,9 +1072,7 @@ class LSGBertForMaskedLM(LSGBertPreTrainedModel, BertForMaskedLM):
     documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
 
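The same swap appears in the three hunks above: the `_keys_to_ignore_on_load_*` lists are dropped in favor of `_tied_weights_keys`, the attribute recent transformers versions (4.29 and later) use to mark parameters that are tied to other weights rather than loaded from the checkpoint. A hedged sketch of how a custom head declares it; the class name is made up, the keys are the ones from this diff:

from transformers import BertConfig, BertLMHeadModel

class MyTiedLMHead(BertLMHeadModel):
    # Tied parameters: from_pretrained() will not warn about these being
    # missing from the checkpoint, since they are re-tied after loading.
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

model = MyTiedLMHead(BertConfig(is_decoder=True))
print(model._tied_weights_keys)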
@@ -1107,8 +1110,6 @@ class LSGBertForSequenceClassification(LSGBertPreTrainedModel, BertForSequenceClassification):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1133,8 +1134,6 @@ class LSGBertForMultipleChoice(LSGBertPreTrainedModel, BertForMultipleChoice):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1156,9 +1155,6 @@ class LSGBertForTokenClassification(LSGBertPreTrainedModel, BertForTokenClassification):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1182,9 +1178,6 @@ class LSGBertForQuestionAnswering(LSGBertPreTrainedModel, BertForQuestionAnswering):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
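Since this file ships as remote code on the Hub, the practical effect is on loading: with the updated file, checkpoints built on it should load cleanly under transformers >= 4.29.1. A usage sketch; the checkpoint id below is illustrative, substitute the repo this file belongs to:

from transformers import AutoModelForMaskedLM, AutoTokenizer

checkpoint = "ccdv/lsg-bert-base-uncased-4096"  # illustrative checkpoint id, an assumption
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(checkpoint, trust_remote_code=True)  # needs transformers >= 4.29.1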
 