#!/usr/bin/env python3
# models/loopllama-1B/setup_from_llama.py
"""Initialize a LoopLlama model from Llama 3.2-1B weights."""
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# Project-local classes. NOTE: this LlamaConfig intentionally replaces
# transformers' class of the same name — it adds the `loop_times` field.
from configuration_llama import LlamaConfig
from modeling_llama import LoopLlamaForCausalLM


def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2,
):
    """Create a LoopLlama model from a pretrained Llama checkpoint.

    Loads the source model, builds a LoopLlama config/model from its
    configuration, transfers all matching weights, saves the result to
    ``target_path``, and verifies the saved model can be reloaded.

    Args:
        source_model_path: Path or hub id of the source Llama model.
        target_path: Directory the converted model is saved into.
        loop_times: Number of loop iterations for LoopLlama.

    Returns:
        Tuple of ``(loop_model, tokenizer)``.
    """
    print(f"Loading original Llama model from {source_model_path}...")
    original_model = LlamaForCausalLM.from_pretrained(
        source_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    original_config = original_model.config

    # AutoTokenizer resolves the correct tokenizer class (fast or slow),
    # so no LlamaTokenizer/except fallback dance is needed.
    tokenizer = AutoTokenizer.from_pretrained(source_model_path)

    print("Creating LoopLlama configuration...")
    # Clone the base configuration and extend it with the loop count.
    loop_config = LlamaConfig(
        **original_config.to_dict(),
        loop_times=loop_times,
    )

    print(f"Creating LoopLlama model with {loop_times} loop times...")
    loop_model = LoopLlamaForCausalLM(loop_config)

    print("Copying weights from original model...")
    # strict=False: LoopLlama may declare parameters the base model lacks
    # (and vice versa). load_state_dict reports both directions, instead of
    # hand-copying tensors through state_dict storage aliasing.
    missing, unexpected = loop_model.load_state_dict(
        original_model.state_dict(), strict=False
    )
    for key in missing:
        print(f"Warning: {key} not found in original model")
    for key in unexpected:
        print(f"Warning: original key {key} unused by LoopLlama")

    print(f"Saving LoopLlama model to {target_path}...")
    loop_model.save_pretrained(target_path)
    loop_config.save_pretrained(target_path)
    tokenizer.save_pretrained(target_path)
    print("Setup completed!")

    # Sanity check: the saved artifacts round-trip through from_pretrained.
    print("Verifying model loading...")
    test_model = LoopLlamaForCausalLM.from_pretrained(
        target_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")

    return loop_model, tokenizer


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--source",
        default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B",
        help="Source Llama model",
    )
    parser.add_argument("--target", default="./", help="Target directory")
    parser.add_argument(
        "--loop_times", type=int, default=3, help="Number of loop times"
    )
    args = parser.parse_args()

    setup_loopllama_from_pretrained(
        source_model_path=args.source,
        target_path=args.target,
        loop_times=args.loop_times,
    )