keep gate in fp32 for 16 bit loras (#1105)
* keep gate in fp32 for loras
* add an e2e check for lora w/o flash attention for mixtral to check the gate
* add checks that the gate is in fp32 for mixtral; add type hints to the train outputs
* mixtral doesn't support basic lora :facepalm:
  - add lora tests @ 16-bit and fix the gate layer check
  - fix the parameter name; it was using the old disco name
  - don't lora over the gate so we can check that it stays in fp32
  - fix the dtype check
* ensure we're using fp16/bf16 for 16-bit, and that qlora is always going to be in uint8
- src/axolotl/train.py +4 -2
- src/axolotl/utils/models.py +1 -1
- tests/e2e/test_mixtral.py +186 -5
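
The gate in question is the MoE router linear layer in each Mixtral block (block_sparse_moe.gate). Its routing logits are sensitive to reduced precision, so for a 16-bit LoRA it should stay in fp32 while the rest of the model is loaded in fp16/bf16; under qlora the 4-bit base weights (gate included) show up as uint8 instead. A minimal sketch for auditing which dtypes the parameters actually ended up in after loading; the summarize_dtypes helper below is hypothetical and not part of axolotl:

    from collections import Counter

    from torch import nn


    def summarize_dtypes(model: nn.Module) -> Counter:
        """Hypothetical helper: count parameter dtypes, bucketed by module-name substring."""
        counts: Counter = Counter()
        for name, param in model.named_parameters():
            bucket = "gate" if "gate" in name else "norm" if "norm" in name else "other"
            counts[bucket, param.dtype] += 1
        return counts


    # usage, e.g. on the model returned by train() below:
    #   print(summarize_dtypes(model))
    # a 16-bit lora run should only show ('gate', torch.float32) in the gate bucket,
    # while a qlora run will show ('gate', torch.uint8) for the quantized base weights.
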
src/axolotl/train.py CHANGED

@@ -5,14 +5,16 @@ import signal
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
+from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
+from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
@@ -43,7 +45,7 @@ class TrainDatasetMeta:
 
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
-):
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
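
Since train() now returns the (PEFT-wrapped or plain) model along with the tokenizer, callers can inspect the loaded modules directly, which is exactly what the new e2e tests do. A minimal sketch of that pattern under the new signature; train_and_check_gate is a hypothetical helper, and the module path is the one used by the tests further down:

    import torch

    from axolotl.cli import load_datasets
    from axolotl.common.cli import TrainerCliArgs
    from axolotl.train import train
    from axolotl.utils.dict import DictDefault


    def train_and_check_gate(cfg: DictDefault) -> None:
        """Hypothetical helper: run training and verify the Mixtral router gate dtype.

        `cfg` is assumed to be an already-normalized config like the ones in the
        e2e tests below.
        """
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        # train() now returns the (PEFT-wrapped) model and the tokenizer.
        model, _tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

        # Module path taken from the e2e tests: first decoder layer's MoE router.
        gate = model.base_model.model.model.layers[0].block_sparse_moe.gate
        expected = torch.uint8 if cfg.load_in_4bit else torch.float32
        assert gate.weight.dtype == expected, gate.weight.dtype
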
src/axolotl/utils/models.py CHANGED

@@ -590,7 +590,7 @@ def load_model(
     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
-        if "norm" in name:
+        if any(m in name for m in ["norm", "gate"]):
             module.to(torch.float32)
     if model_config.model_type == "btlm":
         # don't upcast lm_head for btlm
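
This one-line change is the actual fix: the fp32 upcast in load_model now matches any module whose name contains "gate" in addition to "norm", so the Mixtral router stays in fp32 when the rest of the model is cast to fp16/bf16. A self-contained toy illustration of the same substring-match upcast; TinyMoEBlock is made up for the example and is not axolotl code:

    import torch
    from torch import nn


    class TinyMoEBlock(nn.Module):
        """Toy stand-in for a Mixtral block: a router gate, a norm, and an expert."""

        def __init__(self, hidden: int = 8, experts: int = 4) -> None:
            super().__init__()
            self.gate = nn.Linear(hidden, experts, bias=False)  # MoE router
            self.input_layernorm = nn.LayerNorm(hidden)
            self.expert_w1 = nn.Linear(hidden, hidden, bias=False)


    model = TinyMoEBlock().to(torch.float16)  # simulate loading in 16-bit

    # same substring match as the patched load_model(): upcast norms and gates to fp32
    for name, module in model.named_modules():
        if any(m in name for m in ["norm", "gate"]):
            module.to(torch.float32)

    assert model.gate.weight.dtype == torch.float32
    assert model.input_layernorm.weight.dtype == torch.float32
    assert model.expert_w1.weight.dtype == torch.float16  # experts stay in 16-bit
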
tests/e2e/test_mixtral.py CHANGED

@@ -7,6 +7,7 @@ import os
 import unittest
 from pathlib import Path
 
+import torch
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -27,7 +28,7 @@ class TestMixtral(unittest.TestCase):
     """
 
     @with_temp_dir
-    def
+    def test_qlora_w_fa2(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -37,10 +38,18 @@ class TestMixtral(unittest.TestCase):
                 "sequence_len": 1024,
                 "load_in_4bit": True,
                 "adapter": "qlora",
-                "lora_r":
-                "lora_alpha":
+                "lora_r": 4,
+                "lora_alpha": 8,
                 "lora_dropout": 0.1,
-                "
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
                 "val_set_size": 0.1,
                 "special_tokens": {},
                 "datasets": [
@@ -65,7 +74,179 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_qlora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_w_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir
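
All four tests reach the router through the same attribute chain and only check layers[0]. If you want the same check across every decoder layer, a hypothetical helper along these lines would do it (assert_gates_dtype is not part of axolotl):

    import torch
    from torch import nn


    def assert_gates_dtype(model: nn.Module, expected: torch.dtype) -> None:
        """Hypothetical test helper: check the router gate dtype in every MoE layer,
        not just layers[0] as the tests above do."""
        checked = 0
        for name, module in model.named_modules():
            if name.endswith("block_sparse_moe.gate"):
                assert module.weight.dtype == expected, f"{name}: {module.weight.dtype}"
                checked += 1
        assert checked > 0, "no block_sparse_moe.gate modules found"


    # usage, mirroring the assertions in the tests above:
    #   model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
    #   assert_gates_dtype(model, torch.uint8)    # qlora: 4-bit base weights
    #   assert_gates_dtype(model, torch.float32)  # 16-bit lora: gate kept in fp32
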