ccdv commited on
Commit
2aa24de
1 Parent(s): 2042dbf
Files changed (2) hide show
  1. README.md +1 -1
  2. modeling_lsg_pegasus.py +7 -1
README.md CHANGED
@@ -9,7 +9,7 @@ pipeline_tag: fill-mask
9
  ---
10
 
11
  # LSG model
12
- **Transformers >= 4.35.2**\
13
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
14
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
15
 
9
  ---
10
 
11
  # LSG model
12
+ **Transformers >= 4.36.1**\
13
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
14
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
15
 
modeling_lsg_pegasus.py CHANGED
@@ -972,6 +972,12 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
972
  self.encoder = LSGPegasusEncoder(config, self.shared)
973
  self.decoder = PegasusDecoder(config, self.shared)
974
 
 
 
 
 
 
 
975
  # Initialize weights and apply final processing
976
  self.post_init()
977
 
@@ -1122,4 +1128,4 @@ try:
1122
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1123
  except:
1124
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1125
- warn("Update to transformers >= 4.35.2 to fix.")
972
  self.encoder = LSGPegasusEncoder(config, self.shared)
973
  self.decoder = PegasusDecoder(config, self.shared)
974
 
975
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
976
+ if self._use_flash_attention_2:
977
+ logger.warning(
978
+ "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
979
+ )
980
+
981
  # Initialize weights and apply final processing
982
  self.post_init()
983
 
1128
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1129
  except:
1130
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1131
+ warn("Update to transformers >= 4.36.1 to fix.")