#!/usr/bin/env python3
"""
Model merging script - merges LoRA weights into the base model.
Intended for inference and deployment.
"""
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
def merge_lora_model(base_model_path, lora_model_path, output_path):
    """
    Merge LoRA weights into the base model.

    Args:
        base_model_path: path to the base model
        lora_model_path: path to the LoRA adapter (training output)
        output_path: path where the merged model is saved
    """
    print("📥 Loading base model...")
    # Load the base model (without quantization)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("📥 Loading LoRA model...")
    # Attach the LoRA adapter to the base model
    model = PeftModel.from_pretrained(base_model, lora_model_path)

    print("🔄 Merging LoRA weights...")
    # Fold the adapter weights into the base weights and drop the PEFT wrappers
    model = model.merge_and_unload()

    print("💾 Saving merged model...")
    # Save the merged model in safetensors format
    model.save_pretrained(output_path, safe_serialization=True)

    # Save the tokenizer alongside the merged model
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    tokenizer.save_pretrained(output_path)

    print(f"✅ Model merged and saved to {output_path}")
def test_merged_model(model_path):
    """Smoke-test the merged model with a single prompt."""
    print("🧪 Testing merged model...")
    # Load the merged model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Test prompt
    test_prompt = "### Human: Create an advertisement for a revolutionary AI-powered smartwatch\n### Assistant:"
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_text = response[len(test_prompt):].strip()

    print("\n📝 Test Prompt: Create an advertisement for a revolutionary AI-powered smartwatch")
    print(f"📄 Generated Response:\n{generated_text}")
def main():
    parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
    parser.add_argument("--base_model", required=True, help="Path to base model")
    parser.add_argument("--lora_model", required=True, help="Path to LoRA model (training output)")
    parser.add_argument("--output", required=True, help="Output path for merged model")
    parser.add_argument("--test", action="store_true", help="Test the merged model")
    args = parser.parse_args()

    # Merge the model
    merge_lora_model(args.base_model, args.lora_model, args.output)

    # Optionally test the merged model
    if args.test:
        test_merged_model(args.output)
if __name__ == "__main__":
    print("📋 Merge LoRA Model Script")
    print("\nUsage:")
    print("python merge_model.py --base_model microsoft/DialoGPT-medium --lora_model ./results --output ./merged_model --test")

    # The original script defined main() but never called it; route CLI
    # invocations through main() and keep the default run as a fallback.
    if len(sys.argv) > 1:
        main()
    else:
        print("\nNo arguments given; running the default configuration:")
        # Default configuration
        merge_lora_model(
            base_model_path="microsoft/DialoGPT-medium",  # replace with the actual OpenAI OSS 120B model
            lora_model_path="./results",
            output_path="./merged_model",
        )
        # Test the merged model
        test_merged_model("./merged_model")
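
# Optional deployment step (a sketch, not part of the original script): push the
# merged checkpoint to the Hugging Face Hub for serving. Requires a prior
# `huggingface-cli login`; the repo id below is a placeholder.
#
#     model = AutoModelForCausalLM.from_pretrained("./merged_model")
#     model.push_to_hub("your-username/merged-model")
#     tokenizer = AutoTokenizer.from_pretrained("./merged_model")
#     tokenizer.push_to_hub("your-username/merged-model")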