pszemraj commited on
Commit
304970e
1 Parent(s): 8267bf4

initial support for device_map=auto

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +3 -1
modeling_mpt.py CHANGED
@@ -23,7 +23,9 @@ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
-
 
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
29
  def __init__(self, config: MPTConfig):
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
+ supports_gradient_checkpointing = True
27
+ _no_split_modules = []
28
+
29
  class MPTModel(MPTPreTrainedModel):
30
 
31
  def __init__(self, config: MPTConfig):