# Llama2-TwAddr-LoRA/scripts/step2_merge.py
from typing import Union
import torch
from peft import PeftModel
from peft.tuners.lora import LoraModel
from transformers import LlamaForCausalLM as ModelCls
from transformers import LlamaTokenizerFast as TkCls
PeftCls = Union[PeftModel, LoraModel]
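
# Base Llama-2 checkpoint, trained LoRA adapter, and output directory for the merged model.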
orig_model = "TheBloke/Llama-2-7B-Chat-fp16"
lora_model = "models/Llama-7B-TwAddr-LoRA"
output_dir = "models/Llama-7B-TwAddr-Merged"
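
# Load the base model in fp16, keeping CPU memory usage low during loading.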
model = ModelCls.from_pretrained(
    orig_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Reset sampling parameters so the generation config passes validation when the model is saved.
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0
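
# Wrap the base model with the trained LoRA adapter weights.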
model: PeftCls = PeftModel.from_pretrained(
    model,
    lora_model,
    torch_dtype=torch.float16,
)
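
# Fold the LoRA deltas into the base weights and drop the PEFT wrapper.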
model = model.merge_and_unload()
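
# Write the merged weights in safetensors format.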
model.save_pretrained(
    output_dir,
    safe_serialization=True,
)
# The tokenizer also needs to be saved separately alongside the merged model.
tk: TkCls = TkCls.from_pretrained(orig_model)
tk.save_pretrained(output_dir)
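
# Optional sanity check (a minimal sketch): the merged directory now loads like a
# plain Hugging Face checkpoint, with no PEFT dependency required.
# merged = ModelCls.from_pretrained(output_dir, torch_dtype=torch.float16)
# merged_tk = TkCls.from_pretrained(output_dir)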