fix for qwen w lora (#906)
src/axolotl/utils/models.py  (+10 -3)
@@ -412,15 +412,22 @@ def load_model(
             module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
+    skip_prepare_model_for_kbit_training = False
+
+    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
+        # Qwen doesn't play nicely with LoRA if this is enabled
+        skip_prepare_model_for_kbit_training = True
+
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        model = prepare_model_for_kbit_training(
-            model, use_gradient_checkpointing=cfg.gradient_checkpointing
-        )
+        if not skip_prepare_model_for_kbit_training:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
         needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
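For reference, the call that the new skip_prepare_model_for_kbit_training flag gates is peft's prepare_model_for_kbit_training. The sketch below is a rough, simplified approximation of what that helper typically does, to show what gets skipped on the Qwen + LoRA path; the function name prepare_model_for_kbit_training_sketch and the body are illustrative, not the actual peft source.

    import torch

    def prepare_model_for_kbit_training_sketch(model, use_gradient_checkpointing=True):
        # Illustrative approximation of peft.prepare_model_for_kbit_training, not the real source.
        # Freeze every base-model parameter; trainable weights come from the LoRA adapter later.
        for param in model.parameters():
            param.requires_grad = False

        # Upcast any remaining fp16/bf16 parameters (e.g. norm layers) to fp32 for numerical stability.
        for param in model.parameters():
            if param.dtype in (torch.float16, torch.bfloat16):
                param.data = param.data.to(torch.float32)

        if use_gradient_checkpointing:
            # Let gradients flow into checkpointed blocks even though the embeddings are frozen,
            # then turn on activation checkpointing.
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()

        return model

Per the inline comment in the diff, Qwen does not play nicely with LoRA when this preparation step runs, so the new flag leaves the Qwen model untouched while gradient checkpointing is still enabled through the existing cfg.gradient_checkpointing branch.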