|
|
|
|
|
|
|
""" |
|
从 Llama 3.2-1B 权重初始化 LoopLlama 模型 |
|
""" |
|
|
|
import os |
|
import torch |
|
from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer |
|
from configuration_llama import LlamaConfig |
|
from modeling_llama import LoopLlamaForCausalLM |
|
|
|
def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2
):
    """Create a LoopLlama model initialized from a pretrained Llama checkpoint.

    Loads the source Llama model, builds a LoopLlama configuration/model with
    the requested number of loop iterations, copies all matching weights,
    saves model + config + tokenizer to ``target_path``, and verifies that the
    saved model can be re-loaded.

    Args:
        source_model_path: Path or hub id of the source Llama model.
        target_path: Directory the LoopLlama model is saved to.
        loop_times: Number of loop iterations stored in the config.

    Returns:
        Tuple of ``(loop_model, tokenizer)``.
    """
    print(f"Loading original Llama model from {source_model_path}...")
    original_model = LlamaForCausalLM.from_pretrained(
        source_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    original_config = original_model.config

    # Llama 3.x checkpoints ship a fast (tiktoken-based) tokenizer that the
    # slow LlamaTokenizer cannot load; fall back to AutoTokenizer in that case.
    # Narrowed from a bare `except:` which would also swallow KeyboardInterrupt.
    try:
        tokenizer = LlamaTokenizer.from_pretrained(source_model_path)
    except Exception:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(source_model_path)

    print("Creating LoopLlama configuration...")
    # NOTE: LlamaConfig here is the project-local configuration_llama class
    # (it shadows the transformers import), which accepts `loop_times`.
    loop_config = LlamaConfig(
        **original_config.to_dict(),
        loop_times=loop_times
    )

    print(f"Creating LoopLlama model with {loop_times} loop times...")
    # Match the source dtype: a freshly constructed model defaults to fp32,
    # which would double the checkpoint size versus the bf16 source weights
    # and disagree with the bf16 reload in the verification step below.
    loop_model = LoopLlamaForCausalLM(loop_config).to(torch.bfloat16)

    print("Copying weights from original model...")
    # load_state_dict(strict=False) replaces the manual per-key copy loop and
    # reports BOTH directions: keys LoopLlama has but the checkpoint lacks
    # (missing) and checkpoint keys with no target (unexpected) — the latter
    # was silently dropped before.
    missing, unexpected = loop_model.load_state_dict(
        original_model.state_dict(), strict=False
    )
    for key in missing:
        print(f"Warning: {key} not found in original model")
    for key in unexpected:
        print(f"Warning: original key {key} has no target in LoopLlama")

    print(f"Saving LoopLlama model to {target_path}...")
    loop_model.save_pretrained(target_path)
    loop_config.save_pretrained(target_path)
    tokenizer.save_pretrained(target_path)

    print("Setup completed!")

    # Round-trip check: the saved artifacts must load back and preserve the
    # custom `loop_times` field.
    print("Verifying model loading...")
    test_model = LoopLlamaForCausalLM.from_pretrained(
        target_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")

    return loop_model, tokenizer
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--source", default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B", help="Source Llama model") |
|
parser.add_argument("--target", default="./", help="Target directory") |
|
parser.add_argument("--loop_times", type=int, default=3, help="Number of loop times") |
|
|
|
args = parser.parse_args() |
|
|
|
setup_loopllama_from_pretrained( |
|
source_model_path=args.source, |
|
target_path=args.target, |
|
loop_times=args.loop_times |
|
) |