#!/usr/bin/env python3
# models/loopllama-1B/setup_from_llama.py
"""
Initialize a LoopLlama model from Llama 3.2-1B weights.
"""
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

# Local LoopLlama classes. This LlamaConfig adds the loop_times field, so it
# must not be shadowed by the stock transformers LlamaConfig.
from configuration_llama import LlamaConfig
from modeling_llama import LoopLlamaForCausalLM

def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2,
):
    """
    Create a LoopLlama model from a pretrained Llama model.

    Args:
        source_model_path: path or hub id of the source Llama model
        target_path: directory to save the converted model to
        loop_times: number of loop iterations to record in the config
    """
print(f"Loading original Llama model from {source_model_path}...")
# 加载原始模型和配置
original_model = LlamaForCausalLM.from_pretrained(
source_model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
original_config = original_model.config
# 加载 tokenizer
try:
tokenizer = LlamaTokenizer.from_pretrained(source_model_path)
except:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(source_model_path)
print("Creating LoopLlama configuration...")
# 创建 LoopLlama 配置
loop_config = LlamaConfig(
**original_config.to_dict(),
loop_times=loop_times
)
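    # Note: assuming the custom LlamaConfig stores loop_times as a regular
    # config attribute, save_pretrained() below will serialize it, so the
    # saved config.json gains a "loop_times" entry alongside the standard
    # Llama fields.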
print(f"Creating LoopLlama model with {loop_times} loop times...")
# 创建 LoopLlama 模型
loop_model = LoopLlamaForCausalLM(loop_config)
print("Copying weights from original model...")
# 复制权重 (只复制存在的键)
original_state_dict = original_model.state_dict()
loop_state_dict = loop_model.state_dict()
# 复制匹配的权重
for key in loop_state_dict.keys():
if key in original_state_dict:
print(f"Copying {key}")
loop_state_dict[key].copy_(original_state_dict[key])
else:
print(f"Warning: {key} not found in original model")
print(f"Saving LoopLlama model to {target_path}...")
# 保存模型和配置
loop_model.save_pretrained(target_path)
loop_config.save_pretrained(target_path)
tokenizer.save_pretrained(target_path)
print("Setup completed!")
# 验证模型可以加载
print("Verifying model loading...")
test_model = LoopLlamaForCausalLM.from_pretrained(
target_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16
)
print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")
return loop_model, tokenizer

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--source", default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B", help="Source Llama model")
    parser.add_argument("--target", default="./", help="Target directory")
    parser.add_argument("--loop_times", type=int, default=3, help="Number of loop times")
    args = parser.parse_args()

    setup_loopllama_from_pretrained(
        source_model_path=args.source,
        target_path=args.target,
        loop_times=args.loop_times,
    )
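
# A minimal smoke-test sketch, kept as comments so the script's behavior is
# unchanged. It assumes the converted checkpoint was saved to "./" and that
# LoopLlamaForCausalLM exposes the standard transformers generate() interface;
# adjust paths and arguments to your setup.
#
#   import torch
#   from transformers import AutoTokenizer
#   from modeling_llama import LoopLlamaForCausalLM
#
#   tok = AutoTokenizer.from_pretrained("./")
#   model = LoopLlamaForCausalLM.from_pretrained(
#       "./", torch_dtype=torch.bfloat16, trust_remote_code=True
#   )
#   inputs = tok("Hello, world", return_tensors="pt")
#   out = model.generate(**inputs, max_new_tokens=20)
#   print(tok.decode(out[0], skip_special_tokens=True))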