fix for transformers >= 4.35.2
Browse files- README.md +1 -1
- modeling_lsg_bart.py +3 -3
README.md
CHANGED
@@ -9,7 +9,7 @@ pipeline_tag: fill-mask
|
|
9 |
---
|
10 |
|
11 |
# LSG model
|
12 |
-
**Transformers >= 4.
|
13 |
**This model relies on a custom modeling file, you need to add trust_remote_code=True**\
|
14 |
**See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
|
15 |
|
|
|
9 |
---
|
10 |
|
11 |
# LSG model
|
12 |
+
**Transformers >= 4.35.2**\
|
13 |
**This model relies on a custom modeling file, you need to add trust_remote_code=True**\
|
14 |
**See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
|
15 |
|
modeling_lsg_bart.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from logging import warn
|
2 |
import torch
|
3 |
from transformers.models.bart.modeling_bart import *
|
4 |
-
from transformers.
|
5 |
import torch.nn as nn
|
6 |
import sys
|
7 |
|
@@ -852,7 +852,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
|
|
852 |
# expand attention_mask
|
853 |
if attention_mask is not None:
|
854 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
855 |
-
attention_mask =
|
856 |
|
857 |
encoder_states = () if output_hidden_states else None
|
858 |
all_attentions = () if output_attentions else None
|
@@ -1093,4 +1093,4 @@ try:
|
|
1093 |
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
1094 |
except:
|
1095 |
warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
|
1096 |
-
warn("Update to transformers >= 4.
|
|
|
1 |
from logging import warn
|
2 |
import torch
|
3 |
from transformers.models.bart.modeling_bart import *
|
4 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
5 |
import torch.nn as nn
|
6 |
import sys
|
7 |
|
|
|
852 |
# expand attention_mask
|
853 |
if attention_mask is not None:
|
854 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
855 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
856 |
|
857 |
encoder_states = () if output_hidden_states else None
|
858 |
all_attentions = () if output_attentions else None
|
|
|
1093 |
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
1094 |
except:
|
1095 |
warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
|
1096 |
+
warn("Update to transformers >= 4.35.2 to fix.")
|