# loopllama-1B / transfer_weights_llama.py
# Uploaded via huggingface_hub by ericzhang0328 (commit 2d5a424, verified)
"""Transfer Llama-3.2-1B weights into a LoopLlama checkpoint.

Loads the base model in fp16, copies its state dict into a freshly
constructed LoopLlamaForCausalLM (custom architecture defined in the
local modeling_llama module), and saves the result as sharded
safetensors into the LoopLlama model directory.
"""
from transformers import AutoModelForCausalLM, AutoConfig
import torch

from modeling_llama import LoopLlamaForCausalLM

BASE_DIR = "/projects/llama-cpt/models/Llama-3.2-1B"
DST_DIR = "/projects/llama-cpt/models/loopllama"

# Load the base checkpoint in fp16, matching the published weights.
base = AutoModelForCausalLM.from_pretrained(BASE_DIR, torch_dtype=torch.float16)
cfg = AutoConfig.from_pretrained(DST_DIR, trust_remote_code=True)

dst = LoopLlamaForCausalLM(cfg)
# strict=False: LoopLlama presumably adds parameters with no counterpart
# in the base checkpoint — the printed lists make the mismatch auditable.
missing, unexpected = dst.load_state_dict(base.state_dict(), strict=False)
print("missing:", missing)
print("unexpected:", unexpected)

# BUG FIX: dst was constructed in the default fp32 dtype, and
# load_state_dict casts the incoming fp16 tensors UP to fp32 — the saved
# checkpoint would be twice the intended size. Cast back to fp16 so the
# saved weights match the base model's dtype.
dst = dst.to(torch.float16)

dst.save_pretrained(
    DST_DIR,
    safe_serialization=True,
    max_shard_size="2GB",
)