add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)
src/axolotl/utils/lora_embeddings.py CHANGED

```diff
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["embed_tokens", "lm_head"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
```
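For reference, a minimal sanity check of the updated lookup, a sketch that assumes axolotl is importable from a checkout containing this change:

```python
# Sketch: exercise the per-architecture lookup added above.
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

assert get_linear_embedding_layers("phi-msft") == ["embd.wte", "lm_head.linear"]
assert get_linear_embedding_layers("gpt_neox") == ["embed_in", "embed_out"]
# any other architecture falls through to the HF-style defaults
assert get_linear_embedding_layers("llama") == ["embed_tokens", "lm_head"]
```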
src/axolotl/utils/models.py CHANGED

```diff
@@ -588,13 +588,14 @@ def load_model(
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
@@ -619,15 +620,12 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
 
```
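The casting loops match substrings against fully qualified module names, so for example `gpt_neox.embed_in` is caught by `embed_in`. A standalone sketch with a toy module tree (illustrative names only, not the real GPT-NeoX classes) shows the effect of the upcast rule:

```python
# Sketch of the upcast rule above, applied to a toy module tree; names are illustrative only.
import torch
from torch import nn

toy = nn.ModuleDict(
    {
        "embed_in": nn.Embedding(10, 4),
        "final_norm": nn.LayerNorm(4),
        "embed_out": nn.Linear(4, 10),
    }
).to(torch.bfloat16)

embedding_modules = ["embed_in", "embed_out"]  # what get_linear_embedding_layers("gpt_neox") returns

for name, module in toy.named_modules():
    if "norm" in name:
        module.to(torch.float32)
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch.float32)

# embed_in, embed_out and final_norm parameters now report torch.float32
print({n: p.dtype for n, p in toy.named_parameters()})
```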
tests/core/test_trainer_builder.py CHANGED

```diff
@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )
 
```
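The new fixture key is there because code exercised by these tests now resolves embedding layer names from `cfg.model_config_type`. A rough, hypothetical standalone illustration of what the fixture feeds downstream, assuming axolotl's `DictDefault` helper from `axolotl.utils.dict`:

```python
# Hypothetical illustration: the fixture's new key is what
# get_linear_embedding_layers() is keyed on downstream.
from axolotl.utils.dict import DictDefault
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

cfg = DictDefault({"model_config_type": "llama"})
print(get_linear_embedding_layers(cfg.model_config_type))  # ['embed_tokens', 'lm_head']
```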
tests/test_validation.py CHANGED

```diff
@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )
 
```
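The updated expectation mirrors the new phi-msft return value. Roughly, the check being tested guards that when extra tokens are added under a (q)lora adapter, `lora_modules_to_save` covers the model's embedding layers; a hedged sketch of that invariant, not the actual validation code:

```python
# Sketch of the invariant the test exercises (not the real validation implementation).
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

lora_modules_to_save = ["embd.wte", "lm_head.linear"]
required = get_linear_embedding_layers("phi-msft")
assert all(layer in lora_modules_to_save for layer in required)
```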