from transformers import AutoModelForCausalLM, AutoConfig
import torch
from safetensors.torch import save_file  # only needed for the manual save commented out below
# Load the pretrained Llama-3.2-1B base weights in fp16.
base_dir = "/projects/llama-cpt/models/Llama-3.2-1B"
base = AutoModelForCausalLM.from_pretrained(base_dir, torch_dtype=torch.float16)
# Build a fresh (randomly initialized) LoopLlama model from the custom config;
# trust_remote_code lets AutoConfig resolve the custom config class.
cfg = AutoConfig.from_pretrained("/projects/llama-cpt/models/loopllama", trust_remote_code=True)
from modeling_llama import LoopLlamaForCausalLM

dst = LoopLlamaForCausalLM(cfg)
# Copy every matching weight from the base model; strict=False tolerates keys
# that exist in only one of the two models (e.g. new loop-specific parameters).
missing, unexpected = dst.load_state_dict(base.state_dict(), strict=False)
print("missing:", missing)
print("unexpected:", unexpected)
# Alternative: write a single safetensors file by hand. save_pretrained below
# is preferable since it also writes config.json and shards large checkpoints.
# state = dst.state_dict()
# save_file(state, "/projects/llama-cpt/models/loopllama/model.safetensors")
# Write the converted model as sharded safetensors alongside its config.
dst.save_pretrained(
    "/projects/llama-cpt/models/loopllama",
    safe_serialization=True,
    max_shard_size="2GB",
)
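
# Optional round-trip check (an assumption: the custom modeling file is
# present in the saved directory and registered via auto_map in config.json,
# so trust_remote_code can resolve the LoopLlama classes on reload):
reloaded = AutoModelForCausalLM.from_pretrained(
    "/projects/llama-cpt/models/loopllama",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Spot-check one shared tensor to confirm the weights survived the round trip.
assert torch.equal(
    reloaded.get_input_embeddings().weight,
    dst.get_input_embeddings().weight,
)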