qwerrwe / src /axolotl /utils /lora_embeddings.py
winglian's picture
add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)
78c5b19 unverified
raw
history blame
377 Bytes
"""
Helpers for LoRA embedding layers.
"""
def get_linear_embedding_layers(model_type):
    """
    Look up the linear embedding layer names needed for LoRA on a model architecture.

    Args:
        model_type: architecture identifier, compared with ``==`` against the
            known values ``"phi-msft"`` and ``"gpt_neox"`` (checked in that order).

    Returns:
        A newly built two-element list of module names. Any ``model_type`` that
        matches neither known value yields ``["embed_tokens", "lm_head"]``.
    """
    # Ordered (architecture, layer names) pairs; scanned with `==` rather than a
    # dict lookup so any model_type value (hashable or not) is handled the same way.
    known_layers = (
        ("phi-msft", ("embd.wte", "lm_head.linear")),
        ("gpt_neox", ("embed_in", "embed_out")),
    )
    for arch, (first_layer, second_layer) in known_layers:
        if model_type == arch:
            # Build a fresh list so callers may mutate the result safely.
            return [first_layer, second_layer]
    return ["embed_tokens", "lm_head"]