# mynewmodel/merge_script.py
# Merge a LoRA adapter into its base causal-LM and save the result.
# (Provenance: Hugging Face repo, commit 9190d78 — "Committing all changes before LFS migration")
import torch
from peft import PeftModel # Ensure you have 'peft' library or modify according to your setup
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import argparse
from utils import get_logger # Ensure this is implemented in your environment
import json
logger = get_logger("merge", "info")
def smart_tokenizer_and_embedding_resize(tokenizer, model, custom_tokens_path=None):
    """Add special (and optional custom) tokens, then resize the model's embeddings.

    Args:
        tokenizer: Hugging Face tokenizer exposing ``add_special_tokens`` /
            ``add_tokens``.
        model: Model exposing ``resize_token_embeddings``.
        custom_tokens_path: Optional path to a newline-separated file of extra
            tokens. Blank lines are ignored.

    Returns:
        int: Total number of tokens added to the tokenizer.
    """
    # Llama-style special tokens; add_special_tokens only adds those missing.
    special_tokens_dict = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>",
    }
    # Load custom tokens if specified.
    custom_tokens = []
    if custom_tokens_path is not None:
        with open(custom_tokens_path, 'r') as file:
            # Filter out blank lines: the unfiltered version registered an
            # empty-string "" token for every blank line in the file.
            custom_tokens = [line.strip() for line in file if line.strip()]
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    if custom_tokens:
        num_added_toks += tokenizer.add_tokens(custom_tokens, special_tokens=True)
    # Embedding matrix must grow to match the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    logger.info(f"Resized tokenizer and model embeddings. Added {num_added_toks} tokens.")
    return num_added_toks
def main():
    """Merge a LoRA adapter into its base model and save the merged weights.

    CLI arguments:
        -bm/--base_model: base model name or path (default Llama-2-7b-chat-hf).
        -lm/--lora_model: path to the LoRA adapter directory (required).
        -o/--output: output directory for the merged model (required).
        --custom_tokens: optional file of extra tokens to add before merging.

    Raises:
        FileNotFoundError: if the LoRA adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-bm", "--base_model", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Base model name or path")
    parser.add_argument("-lm", "--lora_model", type=str, required=True, help="Path to the Lora model directory")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory for the merged model")
    parser.add_argument("--custom_tokens", type=str, default=None, help="Path to a file containing custom tokens")
    args = parser.parse_args()

    if not os.path.exists(args.lora_model):
        raise FileNotFoundError(f"LoRA model directory {args.lora_model} not found.")
    os.makedirs(args.output, exist_ok=True)

    # Load the base model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # Adjust tokenizer and model for any additional tokens BEFORE attaching the
    # adapter, so embedding sizes line up with the adapter's training setup.
    smart_tokenizer_and_embedding_resize(tokenizer, model, args.custom_tokens)

    # Load and merge the LoRA model.
    logger.info("Loading and merging the LoRA model...")
    # FIX: PeftModel.from_pretrained has no `merge_with_base` kwarg — passing it
    # raises (or is silently ignored by older versions) and the adapter was
    # never folded into the base weights. Merge explicitly instead.
    lora_model = PeftModel.from_pretrained(model, args.lora_model)
    merged_model = lora_model.merge_and_unload()

    # Save the merged model (plain transformers model, no PEFT wrapper) and tokenizer.
    merged_model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    logger.info(f"Merged model saved to {args.output}")
# Script entry point: parse CLI args and run the merge.
if __name__ == "__main__":
    main()