from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import logging
import argparse

logger = logging.getLogger(__name__)


def merge_lora(base_model_name_or_path, peft_model_path, output_dir,
               device="auto", push_to_hub=False):
    # Resolve device placement: "auto" lets accelerate shard the model across
    # available devices; otherwise pin every module to the requested device.
    if device == "auto":
        device_arg = {"device_map": "auto"}
    else:
        device_arg = {"device_map": {"": device}}

    logger.info(f"Loading base model: {base_model_name_or_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        return_dict=True,
        torch_dtype=torch.float16,
        **device_arg,
    )

    logger.info(f"Loading PEFT adapter: {peft_model_path}")
    model = PeftModel.from_pretrained(
        base_model,
        peft_model_path,
        torch_dtype=torch.float16,
        **device_arg,
    )

    # Fold the LoRA weights into the base model and drop the adapter wrappers,
    # leaving a plain transformers model that no longer depends on peft.
    logger.info("Running merge_and_unload")
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    if push_to_hub:
        logger.info("Saving to hub ...")
        model.push_to_hub(output_dir, use_temp_dir=False)
        tokenizer.push_to_hub(output_dir, use_temp_dir=False)
    else:
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Model saved to {output_dir}")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_name_or_path", type=str)
    parser.add_argument("--peft_model_path", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--push_to_hub", action="store_true")
    args = parser.parse_args()

    merge_lora(
        base_model_name_or_path=args.base_model_name_or_path,
        peft_model_path=args.peft_model_path,
        output_dir=args.output_dir,
        device=args.device,
        push_to_hub=args.push_to_hub,
    )
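
Once the merge completes, the output directory holds a self-contained transformers checkpoint that can be reloaded without peft installed. The snippet below is a minimal sanity check, not part of the script itself; the ./merged-model path is a placeholder for whatever --output_dir was used.

# Hypothetical invocation of the script above (paths are placeholders):
#   python merge_lora.py --base_model_name_or_path <base-model> \
#       --peft_model_path ./lora-checkpoint --output_dir ./merged-model

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

merged_dir = "./merged-model"  # placeholder: whatever --output_dir pointed to
model = AutoModelForCausalLM.from_pretrained(
    merged_dir, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(merged_dir)

# Generate a few tokens to confirm the merged weights load and run standalone.
inputs = tokenizer("Hello,", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))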