Update modeling_lsg_mbart.py
Browse files- modeling_lsg_mbart.py +0 -4
modeling_lsg_mbart.py
CHANGED
@@ -673,10 +673,6 @@ class LSGMBartPretrainedModel(MBartPreTrainedModel):
|
|
673 |
base_model_prefix = "model"
|
674 |
supports_gradient_checkpointing = True
|
675 |
|
676 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
677 |
-
if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartEncoder)):
|
678 |
-
module.gradient_checkpointing = value
|
679 |
-
|
680 |
|
681 |
class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
|
682 |
"""
|
|
|
673 |
base_model_prefix = "model"
|
674 |
supports_gradient_checkpointing = True
|
675 |
|
|
|
|
|
|
|
|
|
676 |
|
677 |
class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
|
678 |
"""
|