pszemraj commited on
Commit
ae54cae
1 Parent(s): 9741fe8

✨ gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +4 -1
modeling_mpt.py CHANGED
@@ -33,7 +33,10 @@ class MPTPreTrainedModel(PreTrainedModel):
33
  base_model_prefix = "model"
34
  supports_gradient_checkpointing = True
35
  _no_split_modules = ["MPTBlock"]
36
-
 
 
 
37
 
38
  class MPTModel(MPTPreTrainedModel):
39
  def __init__(self, config: MPTConfig):
 
33
  base_model_prefix = "model"
34
  supports_gradient_checkpointing = True
35
  _no_split_modules = ["MPTBlock"]
36
+
37
+ def _set_gradient_checkpointing(self, module, value=False):
38
+ if isinstance(module, MPTModel):
39
+ module.gradient_checkpointing = value
40
 
41
  class MPTModel(MPTPreTrainedModel):
42
  def __init__(self, config: MPTConfig):