Nanobit committed on
Commit
52dd92a
1 Parent(s): 8888959

Feat: Update validate_config and add tests

Browse files
src/axolotl/utils/validation.py CHANGED
@@ -3,24 +3,39 @@ import logging
3
 
4
  def validate_config(cfg):
5
  if cfg.load_4bit:
6
- raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
 
 
7
 
8
  if cfg.adapter == "qlora":
9
  if cfg.merge_lora:
10
  # can't merge qlora if loaded in 8bit or 4bit
11
- assert cfg.load_in_8bit is not True
12
- assert cfg.gptq is not True
13
- assert cfg.load_in_4bit is not True
 
 
 
 
 
 
14
  else:
15
- assert cfg.load_in_8bit is not True
16
- assert cfg.gptq is not True
17
- assert cfg.load_in_4bit is True
 
 
 
 
 
18
 
19
  if not cfg.load_in_8bit and cfg.adapter == "lora":
20
  logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
21
-
22
  if cfg.trust_remote_code:
23
- logging.warning("`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.")
 
 
24
 
25
  # TODO
26
  # MPT 7b
 
3
 
4
  def validate_config(cfg):
5
  if cfg.load_4bit:
6
+ raise ValueError(
7
+ "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
8
+ )
9
 
10
  if cfg.adapter == "qlora":
11
  if cfg.merge_lora:
12
  # can't merge qlora if loaded in 8bit or 4bit
13
+ if cfg.load_in_8bit:
14
+ raise ValueError("Can't merge qlora if loaded in 8bit")
15
+
16
+ if cfg.gptq:
17
+ raise ValueError("Can't merge qlora if gptq")
18
+
19
+ if cfg.load_in_4bit:
20
+ raise ValueError("Can't merge qlora if loaded in 4bit")
21
+
22
  else:
23
+ if cfg.load_in_8bit:
24
+ raise ValueError("Can't load qlora in 8bit")
25
+
26
+ if cfg.gptq:
27
+ raise ValueError("Can't load qlora if gptq")
28
+
29
+ if not cfg.load_in_4bit:
30
+ raise ValueError("Require cfg.load_in_4bit to be True for qlora")
31
 
32
  if not cfg.load_in_8bit and cfg.adapter == "lora":
33
  logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
34
+
35
  if cfg.trust_remote_code:
36
+ logging.warning(
37
+ "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
38
+ )
39
 
40
  # TODO
41
  # MPT 7b
tests/test_validation.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import pytest
4
+
5
+ from axolotl.utils.validation import validate_config
6
+ from axolotl.utils.dict import DictDefault
7
+
8
+
9
+ class ValidationTest(unittest.TestCase):
10
+ def test_load_4bit_deprecate(self):
11
+ cfg = DictDefault(
12
+ {
13
+ "load_4bit": True,
14
+ }
15
+ )
16
+
17
+ with pytest.raises(ValueError):
18
+ validate_config(cfg)
19
+
20
+ def test_qlora(self):
21
+ base_cfg = DictDefault(
22
+ {
23
+ "adapter": "qlora",
24
+ }
25
+ )
26
+
27
+ cfg = base_cfg | DictDefault(
28
+ {
29
+ "load_in_8bit": True,
30
+ }
31
+ )
32
+
33
+ with pytest.raises(ValueError, match=r".*8bit.*"):
34
+ validate_config(cfg)
35
+
36
+ cfg = base_cfg | DictDefault(
37
+ {
38
+ "gptq": True,
39
+ }
40
+ )
41
+
42
+ with pytest.raises(ValueError, match=r".*gptq.*"):
43
+ validate_config(cfg)
44
+
45
+ cfg = base_cfg | DictDefault(
46
+ {
47
+ "load_in_4bit": False,
48
+ }
49
+ )
50
+
51
+ with pytest.raises(ValueError, match=r".*4bit.*"):
52
+ validate_config(cfg)
53
+
54
+ cfg = base_cfg | DictDefault(
55
+ {
56
+ "load_in_4bit": True,
57
+ }
58
+ )
59
+
60
+ validate_config(cfg)
61
+
62
+ def test_qlora_merge(self):
63
+ base_cfg = DictDefault(
64
+ {
65
+ "adapter": "qlora",
66
+ "merge_lora": True,
67
+ }
68
+ )
69
+
70
+ cfg = base_cfg | DictDefault(
71
+ {
72
+ "load_in_8bit": True,
73
+ }
74
+ )
75
+
76
+ with pytest.raises(ValueError, match=r".*8bit.*"):
77
+ validate_config(cfg)
78
+
79
+ cfg = base_cfg | DictDefault(
80
+ {
81
+ "gptq": True,
82
+ }
83
+ )
84
+
85
+ with pytest.raises(ValueError, match=r".*gptq.*"):
86
+ validate_config(cfg)
87
+
88
+ cfg = base_cfg | DictDefault(
89
+ {
90
+ "load_in_4bit": True,
91
+ }
92
+ )
93
+
94
+ with pytest.raises(ValueError, match=r".*4bit.*"):
95
+ validate_config(cfg)