Feat: Warns to add to modules_to_save when adding tokens or switching special_tokens (#787)
* Feat: Auto add to modules_to_save when adding tokens
* fix: swap to error instead of warning
* feat: add check when special_tokens differ and add test
- src/axolotl/utils/config.py +14 -0
- src/axolotl/utils/models.py +17 -0
- tests/test_tokenizers.py +36 -0
- tests/test_validation.py +37 -0
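
Taken together, these changes mean that a config which adds new tokens or overrides special tokens while training a LoRA/QLoRA adapter must also list `embed_tokens` and `lm_head` under `lora_modules_to_save`; otherwise loading now fails with a ValueError instead of only warning. A minimal sketch of a passing config, mirroring the new test in tests/test_validation.py (values are illustrative, not prescribed by this PR):

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Illustrative config: adding a new token while training a QLoRA adapter.
# With both embed_tokens and lm_head listed, the new check passes;
# drop either entry and validate_config raises ValueError.
cfg = DictDefault(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],
        "lora_modules_to_save": ["embed_tokens", "lm_head"],
    }
)
validate_config(cfg)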
src/axolotl/utils/config.py
CHANGED
@@ -448,6 +448,20 @@ def validate_config(cfg):
    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
        raise ValueError("neftune_noise_alpha must be > 0.0")

+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(
+                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
+            )
+        )
+    ):
+        raise ValueError(
+            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
+        )
+
    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
src/axolotl/utils/models.py
CHANGED
@@ -136,6 +136,23 @@ def load_tokenizer(cfg):

    if cfg.special_tokens:
        for k, val in cfg.special_tokens.items():
+            # check if new special token is not already in tokenizer and
+            # is adapter training to make sure lora_modules_to_save is set
+            if (
+                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
+                and cfg.adapter
+                and (
+                    not cfg.lora_modules_to_save
+                    or not all(
+                        x in cfg.lora_modules_to_save
+                        for x in ["embed_tokens", "lm_head"]
+                    )
+                )
+            ):
+                raise ValueError(
+                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                )
+
            tokenizer.add_special_tokens(
                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )
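
For context (this note and sketch are not part of the diff): axolotl's `lora_modules_to_save` maps to PEFT's `modules_to_save`, which keeps full, trainable copies of the named modules and saves them alongside the adapter. Adding or changing tokens alters rows of the embedding matrix and the LM head, so those two modules must be fully trained and saved for the new tokens to be usable. A minimal sketch, assuming the standard `peft.LoraConfig` API with illustrative hyperparameters:

from peft import LoraConfig

# Sketch only: hyperparameters are illustrative, not taken from this PR.
# modules_to_save keeps full (non-LoRA) copies of embed_tokens and lm_head
# trainable and saved, so newly added or changed tokens get real embeddings.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
)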
tests/test_tokenizers.py
CHANGED
@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
 """
 import unittest

+import pytest
+
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer

@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
        tokenizer = load_tokenizer(cfg)
        assert "Fast" not in tokenizer.__class__.__name__

+    def test_special_tokens_modules_to_save(self):
+        # setting special_tokens to new token
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        with pytest.raises(
+            ValueError,
+            match=r".*Please set lora_modules_to_save*",
+        ):
+            load_tokenizer(cfg)
+
+        # setting special_tokens but not changing from default
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "<s>"},
+            }
+        )
+        load_tokenizer(cfg)
+
+        # non-adapter setting special_tokens
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        load_tokenizer(cfg)
+

if __name__ == "__main__":
    unittest.main()
tests/test_validation.py
CHANGED
@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):

        validate_config(cfg)

+    def test_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        validate_config(cfg)
+

class ValidationWandbTest(ValidationTest):
    """
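
One way to exercise the new checks locally, using standard pytest keyword selection against the two added tests:

pytest tests/test_validation.py tests/test_tokenizers.py -k "add_tokens_adapter or special_tokens_modules_to_save"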