small fix

- README.md +1 -1
- modeling_lsg_bert.py +15 -3
README.md CHANGED

@@ -7,7 +7,7 @@ pipeline_tag: fill-mask
 ---
 
 # LSG model
-**Transformers >= 4.
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
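The README lines above describe how the checkpoint must be loaded. A minimal sketch of that usage follows; the repo id is a placeholder, not this model's actual Hub identifier:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Placeholder repo id; substitute this model's actual Hub identifier.
model_id = "user/lsg-bert-base-4096"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True is required because the architecture is defined in
# the repo's custom modeling_lsg_bert.py, not in the transformers library.
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
```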
modeling_lsg_bert.py CHANGED

@@ -411,8 +411,13 @@ class LSGBertEmbeddings(BertEmbeddings):
         self.block_size = config.block_size
 
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -1005,6 +1010,7 @@ class LSGBertEncoder(BertEncoder):
         encoder_outputs.last_hidden_state = sequence_output
         return encoder_outputs
 
+
 class LSGBertPreTrainedModel(BertPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1039,6 +1045,12 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
                 "Cross attention is computed using full attention since it is not LSG compatible."
             )
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
+            )
+
         # Initialize weights and apply final processing
         self.post_init()
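The hunk above makes the model check the requested attention backend via `config._attn_implementation` and warn that LSG's custom attention does not support flash-attention 2. A hedged sketch of how a caller would sidestep the warning; the repo id is again a placeholder, and `attn_implementation` is the standard `from_pretrained` keyword in transformers >= 4.36:

```python
from transformers import AutoModel

# Requesting eager attention avoids the flash-attention warning; the LSG
# attention pattern is implemented in the custom modeling file anyway.
model = AutoModel.from_pretrained(
    "user/lsg-bert-base-4096",   # placeholder repo id
    trust_remote_code=True,
    attn_implementation="eager",
)
```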
@@ -1228,4 +1240,4 @@ try:
     str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.36.1 to fix.")
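When the `except` branch above fires (older transformers), auto-class registration is skipped and the custom modeling file has to travel with the checkpoint by hand, as the first warning says. A minimal sketch of that manual fallback, assuming a loaded `model` and illustrative paths:

```python
import shutil

# Save the checkpoint, then copy the custom modeling file next to it so the
# saved model can still be loaded with trust_remote_code=True.
model.save_pretrained("./my-lsg-bert")
shutil.copy("modeling_lsg_bert.py", "./my-lsg-bert/modeling_lsg_bert.py")
```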