recast LoraLayer, norm, lm_head + embed_tokens weights per original qlora (#393)

* recast LoraLayer, norm, lm_head + embed_tokens weights per the original qlora implementation
* try again for the fix
* refactor torch dtype picking
* linter fixes
* add missing import for LoraLayer
* fix install for tests now that peft is involved
- .github/workflows/tests.yml +1 -1
- setup.py +3 -0
- src/axolotl/utils/config.py +7 -0
- src/axolotl/utils/models.py +22 -21
.github/workflows/tests.yml

@@ -24,7 +24,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install -e .
+          pip install -e .[peft]
           pip install -r requirements-tests.txt
 
       - name: Run tests
setup.py

@@ -32,5 +32,8 @@ setup(
         "extras": [
             "deepspeed",
         ],
+        "peft": [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
     },
 )
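For context, a minimal sketch of where the new optional dependency group sits in `setup.py` — only the `peft` extra (and the existing `deepspeed` extra) comes from this diff; the other fields are illustrative placeholders:

```python
# setup.py (sketch) -- the "peft" extra matches this PR; name/packages are placeholders.
from setuptools import find_packages, setup

setup(
    name="axolotl",
    packages=find_packages("src"),
    package_dir={"": "src"},
    extras_require={
        "extras": [
            "deepspeed",
        ],
        # pulled in via `pip install -e .[peft]`, which is what the CI workflow now runs
        "peft": [
            "peft @ git+https://github.com/huggingface/peft.git",
        ],
    },
)
```

Installing from the git URL presumably tracks peft's main branch until a release with the needed behavior is available; a pinned released version could be substituted later.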
src/axolotl/utils/config.py

@@ -62,6 +62,13 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    if cfg.bf16 or cfg.bfloat16:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
+        cfg.torch_dtype = torch.float16
+    else:
+        cfg.torch_dtype = torch.float32
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
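This moves dtype selection out of `load_model` and onto the config object, so later code (model loading, quantization config, weight recasting) can reuse `cfg.torch_dtype` instead of a local variable. A standalone sketch of the same precedence, with a hypothetical minimal config for illustration only:

```python
import torch
from types import SimpleNamespace

def resolve_torch_dtype(cfg):
    """Same precedence as normalize_config: bf16 > 8-bit/fp16 > fp32 fallback."""
    if cfg.bf16 or cfg.bfloat16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        return torch.float16
    return torch.float32

# hypothetical config values, for illustration only
cfg = SimpleNamespace(bf16=False, bfloat16=False, load_in_8bit=True, fp16=False, float16=False)
assert resolve_torch_dtype(cfg) == torch.float16
```

Downstream, this value is what flows into `torch_dtype=` for the `from_pretrained` calls and `bnb_4bit_compute_dtype=` in the quantization config below.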
src/axolotl/utils/models.py

@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -146,12 +147,6 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    if cfg.bf16 or cfg.bfloat16:
-        torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
     try:
         if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -183,7 +178,7 @@ def load_model(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch_dtype,
+            bnb_4bit_compute_dtype=cfg.torch_dtype,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
@@ -242,7 +237,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,7 +272,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -308,7 +303,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -322,7 +317,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -356,16 +351,6 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch_dtype)
-
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(cfg.torch_dtype)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module = module.to(cfg.torch_dtype)
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module = module.to(cfg.torch_dtype)
+
     model.print_trainable_parameters()
 
     return model, lora_config
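Net effect: the recast no longer runs only for the flash-attention path inside `load_model`; it now runs right after `get_peft_model` in `load_lora`, mirroring how the original qlora repo treats LoRA, norm, and embedding weights. A minimal standalone sketch of that loop, assuming a PEFT-wrapped causal LM and an already-resolved compute dtype (the helper name is illustrative; the module-name checks match the diff):

```python
import torch
from peft.tuners.lora import LoraLayer

def recast_for_qlora(model, torch_dtype, flash_attention=False, is_llama_derived_model=False):
    """Recast LoraLayer, norm, lm_head and embed_tokens weights per the original qlora recipe."""
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.to(torch_dtype)            # LoRA adapter weights in the compute dtype
        if "norm" in name:
            module.to(torch.float32)          # norms kept in fp32 for numerical stability
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):     # skip containers that have no weight of their own
                module.to(torch_dtype)

    # The loop above leaves LlamaRMSNorm layers in fp32; flash-attn kernels expect
    # fp16/bf16 inputs, so cast the norms back for that path.
    if flash_attention and is_llama_derived_model:
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(torch_dtype)
    return model
```

The diff writes `module = module.to(...)`; since `nn.Module.to` casts in place and returns the same module, the sketch drops the reassignment without changing behavior.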