tmm1 committed on
Commit 0d2e34f · unverified · 2 Parent(s): b56a6c0 2eda9e0

Merge pull request #336 from tmm1/flash-attn

Fix flash-attn + qlora not working with llama models
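For context on the fix: after peft's prepare_model_for_kbit_training (the QLoRA preparation step), the LlamaRMSNorm and embedding/lm_head weights sit in fp32 for training stability, but the flash-attn kernels only accept fp16/bf16 tensors, so the patched attention breaks until those modules are cast back. The diff below moves that cast out of the adapter loader to just after the kbit preparation, and uses the configured torch_dtype instead of hardcoding float16. A minimal standalone sketch of the cast, mirroring the new hunk (cast_norms_for_flash_attn is an illustrative helper name, not part of axolotl; model and torch_dtype come from whatever load code surrounds it):

import torch


def cast_norms_for_flash_attn(model, torch_dtype=torch.bfloat16):
    # prepare_model_for_kbit_training leaves norm/embedding weights in fp32;
    # flash-attn only supports fp16/bf16 activations, so cast those modules back.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)
    return model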

src/axolotl/{flash_attn.py → monkeypatch/llama_attn_hijack_flash.py} RENAMED
File without changes
src/axolotl/utils/models.py CHANGED
@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
@@ -331,6 +333,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
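Putting the two hunks together, here is a minimal end-to-end sketch of the load order this PR establishes for QLoRA + flash-attn on a llama model. It is only an illustration using the standard transformers/peft/bitsandbytes APIs, not axolotl's actual load_model; the model id and LoRA settings are placeholders:

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

torch_dtype = torch.bfloat16

# 1. Patch LlamaAttention with the flash-attn forward before the model is built.
replace_llama_attn_with_flash_attn()

# 2. Load the base model in 4-bit (QLoRA).
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    ),
    torch_dtype=torch_dtype,
    device_map="auto",
)

# 3. kbit preparation upcasts norm/embedding weights to fp32 ...
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# 4. ... so cast them back; the flash-attn kernels reject fp32 inputs.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch_dtype)

# 5. Only then attach the LoRA adapter.
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()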