new validation for mpt w grad checkpoints
Browse files- src/axolotl/utils/validation.py +5 -0
- tests/test_validation.py +14 -0
src/axolotl/utils/validation.py
CHANGED
@@ -57,6 +57,11 @@ def validate_config(cfg):
|
|
57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
58 |
raise ValueError("FSDP is not supported for falcon models")
|
59 |
|
|
|
|
|
|
|
|
|
|
|
60 |
# TODO
|
61 |
# MPT 7b
|
62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
58 |
raise ValueError("FSDP is not supported for falcon models")
|
59 |
|
60 |
+
if (
|
61 |
+
cfg.base_model and "mpt" in cfg.base_model.lower()
|
62 |
+
) and cfg.gradient_checkpointing:
|
63 |
+
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
64 |
+
|
65 |
# TODO
|
66 |
# MPT 7b
|
67 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
|
|
198 |
)
|
199 |
|
200 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
)
|
199 |
|
200 |
validate_config(cfg)
|
201 |
+
|
202 |
+
def test_mpt_gradient_checkpointing(self):
|
203 |
+
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
|
204 |
+
|
205 |
+
# Check for lower-case
|
206 |
+
cfg = DictDefault(
|
207 |
+
{
|
208 |
+
"base_model": "mosaicml/mpt-7b",
|
209 |
+
"gradient_checkpointing": True,
|
210 |
+
}
|
211 |
+
)
|
212 |
+
|
213 |
+
with pytest.raises(ValueError, match=regex_exp):
|
214 |
+
validate_config(cfg)
|