winglian committed on
Commit
14668fa
1 Parent(s): fe0b768

new validation for MPT with gradient checkpointing

Browse files
src/axolotl/utils/validation.py CHANGED
@@ -57,6 +57,11 @@ def validate_config(cfg):
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
 
 
 
 
 
60
  # TODO
61
  # MPT 7b
62
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
60
+ if (
61
+ cfg.base_model and "mpt" in cfg.base_model.lower()
62
+ ) and cfg.gradient_checkpointing:
63
+ raise ValueError("gradient_checkpointing is not supported for MPT models")
64
+
65
  # TODO
66
  # MPT 7b
67
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
198
  )
199
 
200
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
 
200
  validate_config(cfg)
201
+
202
+ def test_mpt_gradient_checkpointing(self):
203
+ regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
204
+
205
+ # Check for lower-case
206
+ cfg = DictDefault(
207
+ {
208
+ "base_model": "mosaicml/mpt-7b",
209
+ "gradient_checkpointing": True,
210
+ }
211
+ )
212
+
213
+ with pytest.raises(ValueError, match=regex_exp):
214
+ validate_config(cfg)