winglian committed
Commit 96deb6b
1 Parent(s): 50682a3

recast loralayer, norm, lmhead + embed token weights per original qlora (#393)


* recast loralayer, norm, lmhead + embed token weights per original qlora

* try again for the fix

* refactor torch dtype picking

* linter fixes

* missing import for LoraLayer

* fix install for tests now that peft is involved
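
Taken together, these commits move the dtype decision into config normalization and thread it through model loading and the post-PEFT recast. Below is a minimal sketch of the resulting flow, with the loading and adapter functions stubbed out; the real functions take more arguments than shown here, so treat the signatures as abridged:

import torch


def normalize_config(cfg):
    # Pick the working dtype once and store it on the config (src/axolotl/utils/config.py).
    if cfg.bf16 or cfg.bfloat16:
        cfg.torch_dtype = torch.bfloat16
    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        cfg.torch_dtype = torch.float16
    else:
        cfg.torch_dtype = torch.float32


def load_model(cfg):  # abridged signature
    # cfg.torch_dtype now feeds from_pretrained(torch_dtype=...) and
    # BitsAndBytesConfig(bnb_4bit_compute_dtype=...) instead of a local torch_dtype.
    ...


def load_lora(model, cfg):  # abridged signature
    # After get_peft_model: LoraLayer modules and lm_head/embed_tokens weights are
    # recast to cfg.torch_dtype and norm layers to float32, per the original qlora recipe.
    ...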

.github/workflows/tests.yml CHANGED
@@ -24,7 +24,7 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install -e .
+          pip install -e .[peft]
           pip install -r requirements-tests.txt

       - name: Run tests
setup.py CHANGED
@@ -32,5 +32,8 @@ setup(
         "extras": [
             "deepspeed",
         ],
+        "peft": [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
     },
 )
src/axolotl/utils/config.py CHANGED
@@ -62,6 +62,13 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False

+    if cfg.bf16 or cfg.bfloat16:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
+        cfg.torch_dtype = torch.float16
+    else:
+        cfg.torch_dtype = torch.float32
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

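To illustrate the precedence this block establishes (bf16 first, then 8-bit/fp16, else full precision), here is a standalone sketch using a plain SimpleNamespace instead of axolotl's real config object, with the branch copied from the hunk above:

from types import SimpleNamespace

import torch


def pick_torch_dtype(cfg):
    # Same branch as the normalize_config hunk above.
    if cfg.bf16 or cfg.bfloat16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        return torch.float16
    return torch.float32


base = dict(bf16=False, bfloat16=False, load_in_8bit=False, fp16=False, float16=False)

# bf16 takes precedence even when 8-bit loading is requested.
assert pick_torch_dtype(SimpleNamespace(**{**base, "bf16": True, "load_in_8bit": True})) is torch.bfloat16

# 8-bit loading without bf16 falls back to fp16 compute.
assert pick_torch_dtype(SimpleNamespace(**{**base, "load_in_8bit": True})) is torch.float16

# No precision flags at all means float32.
assert pick_torch_dtype(SimpleNamespace(**base)) is torch.float32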
 
src/axolotl/utils/models.py CHANGED
@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft.tuners.lora import LoraLayer
 from transformers import ( # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -146,12 +147,6 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()

-    if cfg.bf16 or cfg.bfloat16:
-        torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
     try:
         if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -183,7 +178,7 @@
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch_dtype,
+            bnb_4bit_compute_dtype=cfg.torch_dtype,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
@@ -242,7 +237,7 @@
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,7 +272,7 @@
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -308,7 +303,7 @@
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -322,7 +317,7 @@
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -356,16 +351,6 @@
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch_dtype)
-
     model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit:
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)

+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(cfg.torch_dtype)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module = module.to(cfg.torch_dtype)
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module = module.to(cfg.torch_dtype)
+
     model.print_trainable_parameters()

     return model, lora_config
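
For reference, the recast that load_lora now performs mirrors the module-casting loop in the upstream qlora repository's model setup. The helper below is an approximate reconstruction from memory, not the verbatim upstream code, so the exact conditions may differ slightly:

import torch
from peft.tuners.lora import LoraLayer


def recast_modules_qlora_style(model, bf16=True):
    """Approximation of the post-PEFT module recast in the original qlora setup code."""
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if bf16:
                # LoRA adapter weights train in bf16 when requested.
                module = module.to(torch.bfloat16)
        if "norm" in name:
            # Norm layers are kept in fp32 for numerical stability.
            module = module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight") and bf16 and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)
    return model

The axolotl version above keys off cfg.torch_dtype rather than a bare bf16 flag, and additionally re-casts norm layers back to the compute dtype when flash attention is enabled on Llama-derived models.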