update for transformers >= 4.29.1
modeling_lsg_pegasus.py (+15 -9)
@@ -678,6 +678,8 @@ class LSGPegasusEncoderLayer(PegasusEncoderLayer):
 class LSGPegasusPreTrainedModel(PegasusPreTrainedModel):
 
     config_class = LSGPegasusConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (PegasusDecoder, PegasusEncoder, LSGPegasusEncoder)):
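Hoisting base_model_prefix and supports_gradient_checkpointing onto the shared LSGPegasusPreTrainedModel base means every LSG head class inherits them, which is what lets transformers >= 4.29.1 toggle checkpointing through the standard PreTrainedModel API. A minimal usage sketch; the repo id below is a placeholder, not something this diff ships:

from transformers import AutoModelForSeq2SeqLM

# Placeholder repo id: any LSG-Pegasus checkpoint that ships this modeling
# file via trust_remote_code behaves the same way.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "some-user/lsg-pegasus-checkpoint", trust_remote_code=True
)

# Because supports_gradient_checkpointing is True, these standard calls route
# to _set_gradient_checkpointing(...) on the Pegasus encoder/decoder modules.
model.gradient_checkpointing_enable()
model.gradient_checkpointing_disable()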
@@ -880,8 +882,13 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
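Upstream PegasusEncoder replaced random.uniform(0, 1) with torch.rand([]) for the LayerDrop draw in transformers 4.29, so the skip decision is governed by torch's seedable RNG rather than Python's random module; this hunk mirrors that. A standalone sketch of the new check (the function name is ours, for illustration only):

import torch

def should_drop_layer(training: bool, layerdrop: float) -> bool:
    # Mirrors the updated logic: draw a scalar in [0, 1) from torch's RNG,
    # so torch.manual_seed(...) makes layer skipping reproducible.
    to_drop = False
    if training:
        dropout_probability = torch.rand([])  # 0-dim tensor
        if dropout_probability < layerdrop:  # skip the layer
            to_drop = True
    return to_drop

torch.manual_seed(0)
print(should_drop_layer(training=True, layerdrop=0.1))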
@@ -925,6 +932,8 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
 
 class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: LSGPegasusConfig):
 
         LSGPegasusPreTrainedModel.__init__(self, config)
@@ -1032,13 +1041,8 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
 class LSGPegasusForConditionalGeneration(LSGPegasusPreTrainedModel, PegasusForConditionalGeneration):
 
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder\.version",
-        r"decoder\.version",
-        r"lm_head\.weight",
-        r"embed_positions\.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: LSGPegasusConfig):
 
@@ -1065,6 +1069,8 @@ class LSGPegasusDecoderWrapper(LSGPegasusPreTrainedModel, PegasusDecoderWrapper)
 
 class LSGPegasusForCausalLM(LSGPegasusPreTrainedModel, PegasusForCausalLM):
 
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config):
 
         LSGPegasusPreTrainedModel.__init__(self, config)
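Since transformers 4.29, tied parameters are declared in _tied_weights_keys instead of being listed as regexes under _keys_to_ignore_on_load_missing; only the untied final_logits_bias buffer still needs the ignore list, and LSGPegasusForCausalLM ties just lm_head.weight. A quick sanity check, assuming the usual Pegasus layout (model.model.shared, model.lm_head) on a constructed LSGPegasusForConditionalGeneration:

def assert_weights_tied(model):
    # After PreTrainedModel.tie_weights() runs, every key declared in
    # _tied_weights_keys should alias the shared embedding's storage.
    shared = model.model.shared.weight
    assert model.model.encoder.embed_tokens.weight.data_ptr() == shared.data_ptr()
    assert model.model.decoder.embed_tokens.weight.data_ptr() == shared.data_ptr()
    assert model.lm_head.weight.data_ptr() == shared.data_ptr()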
|