ccdv committed
Commit 8788f75
1 Parent(s): 54a7242
Files changed (2):
  1. README.md +1 -1
  2. modeling_lsg_bert.py +15 -3
README.md CHANGED
@@ -7,7 +7,7 @@ pipeline_tag: fill-mask
 ---
 
 # LSG model
-**Transformers >= 4.35.2**\
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
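For context, the `trust_remote_code=True` requirement in the README means loading looks like the following minimal sketch (the repo id is a placeholder, not taken from this commit):

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# trust_remote_code=True lets transformers execute the custom
# modeling_lsg_bert.py shipped alongside the checkpoint.
model = AutoModelForMaskedLM.from_pretrained(
    "ccdv/lsg-bert-base-uncased-4096",  # placeholder repo id
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")
```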
modeling_lsg_bert.py CHANGED
@@ -411,8 +411,13 @@ class LSGBertEmbeddings(BertEmbeddings):
         self.block_size = config.block_size
 
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
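The rewritten signature matches the upstream transformers style: one argument per line, explicit `Optional` annotations, and a return type. A standalone toy sketch of the same pattern (illustrative only, not the LSG class):

```python
from typing import Optional

import torch


class ToyEmbeddings(torch.nn.Module):
    """Toy module mirroring the annotated-forward style above."""

    def __init__(self, vocab_size: int = 100, dim: int = 8):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, dim)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # Accept either token ids or precomputed embeddings, as in BERT.
        if input_ids is not None:
            return self.word_embeddings(input_ids)
        return inputs_embeds


emb = ToyEmbeddings()
print(emb(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 8])
```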
@@ -1005,6 +1010,7 @@ class LSGBertEncoder(BertEncoder):
         encoder_outputs.last_hidden_state = sequence_output
         return encoder_outputs
 
+
 class LSGBertPreTrainedModel(BertPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1039,6 +1045,12 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
                 "Cross attention is computed using full attention since it is not LSG compatible."
             )
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
+            )
+
         # Initialize weights and apply final processing
         self.post_init()
 
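Since transformers 4.36 a backend can be requested at load time via `attn_implementation`; on an LSG checkpoint that request now logs the warning added above. A hedged sketch (placeholder repo id, and it assumes the flash-attn package is installed, since transformers validates the request):

```python
from transformers import AutoModel

# Requesting flash-attention on an LSG checkpoint triggers the
# "[WARNING flash-attention]" log added in this commit.
model = AutoModel.from_pretrained(
    "ccdv/lsg-bert-base-uncased-4096",  # placeholder repo id
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
```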
@@ -1228,4 +1240,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.35.2 to fix.")
+    warn("Update to transformers >= 4.36.1 to fix.")
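The surrounding `try/except` registers each custom class with the Auto API via `register_for_auto_class`, which is what lets `save_pretrained(...)` record the class under `auto_map` and ship the custom modeling file with the checkpoint. A minimal sketch of that mechanism on a toy subclass:

```python
from transformers import BertModel


class MyBertModel(BertModel):
    """Toy subclass, standing in for the LSG classes."""


# Marks the class so that save_pretrained() records it under auto_map
# and copies the defining module next to the saved weights.
MyBertModel.register_for_auto_class("AutoModel")
```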
 