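"""Merge LoRA weights into a base LLaMA model and save the merged state dict."""
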
import sys
import time
from pathlib import Path
from typing import Optional

import lightning as L
import torch
import torch.nn as nn

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from lit_llama.lora import lora


def del_lora_state_dict(model: nn.Module):
    """Returns the model state dict with all LoRA-specific parameters removed."""
    base_model_dict = model.state_dict()
    keys_to_delete = [k for k in base_model_dict if "lora_" in k]
    for del_key in keys_to_delete:
        del base_model_dict[del_key]
    return base_model_dict
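
# The keys removed above are the LoRA adapter matrices, e.g. (assuming the
# layer naming used in `lora_model_lookup` below):
#   "transformer.h.0.attn.c_attn.lora_A"
#   "transformer.h.0.attn.c_attn.lora_B"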


def lora_model_lookup(checkpoint: dict) -> int:
    """Returns the LoRA rank from the adapter checkpoint.

    `lora_B` has shape `(out_features, r)`, so the rank can be read off its
    second dimension.
    """
    return checkpoint["transformer.h.0.attn.c_attn.lora_B"].shape[1]
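
# A quick sanity check, assuming a checkpoint produced by `finetune_lora.py`
# with lit-llama's default LoRA rank (r=8):
#   with lazy_load(lora_path) as lora_checkpoint:
#       print(lora_model_lookup(lora_checkpoint))  # -> 8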


def main(
    accelerator: str = "auto",
    lora_path: Optional[Path] = None,
    checkpoint_path: Optional[Path] = None,
    dtype: str = "bfloat16",
) -> None:
    """Merges LoRA weights into the base model.

    Args:
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        lora_path: Path to the checkpoint with trained LoRA weights, which is
            the output of `finetune_lora.py`.
        checkpoint_path: The base model checkpoint path to load.
        dtype: `torch.dtype` to work with, e.g. ``"bfloat16"``.
    """
    if not lora_path:
        lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
    if not checkpoint_path:
        checkpoint_path = Path("./checkpoints/lit-llama/7B/lit-llama.pth")

    assert lora_path.is_file()
    assert checkpoint_path.is_file()

    fabric = L.Fabric(accelerator=accelerator, devices=1)

    # resolve the dtype name (e.g. "bfloat16") to the corresponding torch.dtype
    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    print("Loading model ...", file=sys.stderr)
    t0 = time.time()

    with (lazy_load(checkpoint_path) as pretrained_checkpoint,
          lazy_load(lora_path) as lora_checkpoint):
        name = llama_model_lookup(pretrained_checkpoint)
        rank = lora_model_lookup(lora_checkpoint)

        # r must match the finetuned checkpoint; alpha/dropout mirror the
        # finetune_lora.py defaults (the alpha/r scaling affects the merge)
        with EmptyInitOnDevice(
            device=fabric.device, dtype=dtype
        ), lora(r=rank, alpha=16, dropout=0.05, enabled=True):
            model = LLaMA.from_name(name)

            # each checkpoint holds only a subset of the keys, hence strict=False:
            # 1. load the pretrained base weights
            model.load_state_dict(pretrained_checkpoint, strict=False)
            # 2. load the fine-tuned LoRA weights on top
            model.load_state_dict(lora_checkpoint, strict=False)

    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    # eval() triggers the weight merge in lit_llama.lora, folding the LoRA
    # matrices into the base linear weights
    model.eval()
    base_model_dict = del_lora_state_dict(model)
    save_path = lora_path.with_stem(f"{lora_path.stem}-lora-merged-weights")
    print("Saving LoRA to base model weights ...")
    torch.save(base_model_dict, save_path)
    print(f"Model saved at {save_path}")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
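
# Example invocation (file name and paths are illustrative):
#   python merge_lora_weights.py \
#       --lora_path out/lora/alpaca/lit-llama-lora-finetuned.pth \
#       --checkpoint_path checkpoints/lit-llama/7B/lit-llama.pth \
#       --dtype bfloat16 --accelerator cpu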