ccdv committed
Commit d766a2f
1 Parent(s): b97ea6d

update for transformers >= 4.29.1

Files changed (1)
  1. modeling_lsg_distilbert.py +15 -33
modeling_lsg_distilbert.py CHANGED
@@ -233,19 +233,25 @@ class CausalAttentionProduct(nn.Module):
         del key_layer
 
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
             # Add causal mask
             causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
             causal_mask = torch.tril(
                 torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
                 diagonal=-1
                 )
-            causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
+
+            # Min value
+            dtype_min = torch.tensor(
+                torch.finfo(attention_scores.dtype).min, device=attention_scores.device, dtype=attention_scores.dtype
+            )
+
+            # Build causal + attention_mask
+            causal_mask = torch.nn.functional.pad(causal_mask.T * dtype_min, (attention_mask.size()[-1] - self.block_size, 0), value=0)
+            attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)
 
+            attention_scores = attention_scores + attention_mask
             del attention_mask
+            del causal_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(dim=-1)(attention_scores)
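
What the new block does: instead of adding the padding mask up front and then overwriting the last block of scores with a scaled causal block, it folds the causal constraint into the additive mask itself. The causal tile is left-padded with zeros up to the key length, added to the padding mask, and the sum is clamped at the dtype minimum so that two stacked "minimum" entries stay finite. A minimal standalone sketch of that construction; the build_combined_mask helper, the 4-D mask shape and the example sizes are illustrative assumptions, not code from this file:

import torch

def build_combined_mask(attention_mask, block_size):
    # attention_mask: additive padding mask of shape (batch, 1, 1, key_len),
    # 0.0 for tokens to keep and the dtype minimum for tokens to drop.
    dtype = attention_mask.dtype
    dtype_min = torch.tensor(torch.finfo(dtype).min, device=attention_mask.device, dtype=dtype)

    # Strictly lower-triangular tile, transposed so the strictly upper triangle
    # of the last block_size x block_size tile of the scores gets masked out.
    causal_mask = torch.tril(
        torch.ones(block_size, block_size, device=attention_mask.device, dtype=dtype),
        diagonal=-1,
    )
    causal_mask = causal_mask.T * dtype_min

    # Left-pad with zeros so the causal tile lines up with the last block_size keys.
    key_len = attention_mask.size(-1)
    causal_mask = torch.nn.functional.pad(causal_mask, (key_len - block_size, 0), value=0)

    # Merge both additive masks; clamping at dtype_min keeps the sum finite
    # (dtype_min + dtype_min would otherwise overflow to -inf).
    return torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0), dtype_min)

# Example: batch of 2, 16 keys, block size 4, last 3 keys padded out.
mask = torch.zeros(2, 1, 1, 16)
mask[..., -3:] = torch.finfo(mask.dtype).min
print(build_combined_mask(mask, block_size=4).shape)  # torch.Size([2, 1, 4, 16])
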
@@ -777,7 +783,7 @@ class LSGTransformer(Transformer):
         attn_mask[..., 0] = mask_value
 
         attn_mask = torch.finfo(x.dtype).min*(1 - attn_mask).unsqueeze(1).unsqueeze(1)
-
+
         encoder_outputs = super().forward(
             x=x,
             attn_mask=attn_mask,
@@ -822,36 +828,12 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
-
-
-        if input_ids is None and inputs_embeds is not None:
-            inputs_embeds = self.embeddings(None, inputs_embeds)
-            if attention_mask is None:
-                n, t, d = inputs_embeds.size()
-                attention_mask = torch.ones(n, t - self.num_global_tokens, device=inputs_embeds.device)
-
-        return super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
 
 class LSGDistilBertForMaskedLM(LSGDistilBertPreTrainedModel, DistilBertForMaskedLM):
 
+    _keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
+    _tied_weights_keys = ["vocab_projector.weight"]
+
     def __init__(self, config):
 
         LSGDistilBertPreTrainedModel.__init__(self, config)
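
Two notes on this hunk: the custom forward override on LSGDistilBertModel (which special-cased inputs_embeds passed without input_ids) is removed, so calls now fall through to DistilBertModel.forward, and LSGDistilBertForMaskedLM declares vocab_projector.weight in _tied_weights_keys, the attribute transformers >= 4.29 uses to track tied parameters, while keeping _keys_to_ignore_on_load_missing for compatibility with older checkpoints. A quick tie check after loading; the Hub id below is only an illustration, and the assertion assumes word-embedding tying is left at its default:

# Sanity check that the MLM projector stays tied to the input embeddings after loading.
# The checkpoint id is an illustrative assumption; any checkpoint built on this modeling
# file with default weight tying should behave the same.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "ccdv/lsg-distilbert-base-uncased-4096", trust_remote_code=True
)
assert model.vocab_projector.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()
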
 