File size: 3,243 Bytes
7ec3f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
# models/loopllama-1B/setup_from_llama.py

"""
从 Llama 3.2-1B 权重初始化 LoopLlama 模型
"""

import os
import torch
from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer
from configuration_llama import LlamaConfig
from modeling_llama import LoopLlamaForCausalLM

def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2
):
    """
    Create a LoopLlama model initialized from a pretrained Llama model.

    Loads the source Llama weights, builds a LoopLlama config/model with the
    requested ``loop_times``, copies every matching weight tensor, saves the
    result to ``target_path``, and reloads it once as a sanity check.

    Args:
        source_model_path: Path or hub id of the source Llama model.
        target_path: Directory the converted model/config/tokenizer are saved to.
        loop_times: Number of loop iterations stored in the LoopLlama config.

    Returns:
        Tuple of (loop_model, tokenizer).
    """
    print(f"Loading original Llama model from {source_model_path}...")

    # Load the original model; its config is the template for the loop config.
    original_model = LlamaForCausalLM.from_pretrained(
        source_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    original_config = original_model.config

    # Load the tokenizer. The slow LlamaTokenizer does not support every
    # checkpoint layout (e.g. Llama 3.x ships tokenizer.json only), so fall
    # back to AutoTokenizer. Catch Exception, not everything: a bare except
    # would also swallow KeyboardInterrupt/SystemExit.
    try:
        tokenizer = LlamaTokenizer.from_pretrained(source_model_path)
    except Exception:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(source_model_path)

    print("Creating LoopLlama configuration...")

    # Build the config via a merged dict rather than **kwargs + loop_times:
    # if the source config already contains loop_times (e.g. converting an
    # already-converted checkpoint), duplicate keywords would raise TypeError.
    config_dict = original_config.to_dict()
    config_dict["loop_times"] = loop_times
    loop_config = LlamaConfig(**config_dict)

    print(f"Creating LoopLlama model with {loop_times} loop times...")

    loop_model = LoopLlamaForCausalLM(loop_config)

    print("Copying weights from original model...")

    original_state_dict = original_model.state_dict()
    loop_state_dict = loop_model.state_dict()

    # Copy only keys present in both models; warn about the rest. The in-place
    # copy_ mutates the target model's storage, so guard it with no_grad.
    with torch.no_grad():
        for key in loop_state_dict.keys():
            if key in original_state_dict:
                print(f"Copying {key}")
                loop_state_dict[key].copy_(original_state_dict[key])
            else:
                print(f"Warning: {key} not found in original model")

    print(f"Saving LoopLlama model to {target_path}...")

    # Persist model weights, config, and tokenizer side by side.
    loop_model.save_pretrained(target_path)
    loop_config.save_pretrained(target_path)
    tokenizer.save_pretrained(target_path)

    print("Setup completed!")

    # Round-trip check: the saved artifact must load back as a LoopLlama model.
    print("Verifying model loading...")
    test_model = LoopLlamaForCausalLM.from_pretrained(
        target_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")

    return loop_model, tokenizer

def _main():
    """Parse CLI arguments and run the LoopLlama conversion."""
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--source", default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B", help="Source Llama model")
    arg_parser.add_argument("--target", default="./", help="Target directory")
    arg_parser.add_argument("--loop_times", type=int, default=3, help="Number of loop times")
    cli = arg_parser.parse_args()

    setup_loopllama_from_pretrained(
        source_model_path=cli.source,
        target_path=cli.target,
        loop_times=cli.loop_times,
    )


if __name__ == "__main__":
    _main()