import glob
import json
import os
from typing import Optional
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
class TargetEmbeddingsAndHead(nn.Module):
"""
Efficiently loads only the embedding layer and lm_head from a pretrained model.
Avoids loading the full model into memory.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@classmethod
def from_pretrained(
cls,
model_path: str,
embed_key: str = "model.embed_tokens.weight",
lm_head_key: str = "lm_head.weight",
cache_dir: Optional[str] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
trust_remote_code: bool = False,
) -> "TargetEmbeddingsAndHead":
# 1. Load Config
config = AutoConfig.from_pretrained(
model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
)
instance = cls(config)
# 2. Resolve Model Path (Handle Hub)
local_model_path = model_path
if not os.path.exists(local_model_path):
try:
local_model_path = snapshot_download(
repo_id=model_path, cache_dir=cache_dir
)
            except Exception as e:
                # Not a resolvable Hub repo ID either; keep the original path and let
                # _load_weights raise a clear error if nothing can be found there.
                print(f"Warning: could not resolve '{model_path}' from the Hub: {e}")
# 3. Load Weights Efficiently
instance._load_weights(local_model_path, embed_key, lm_head_key)
# 4. Move to Device & Freeze
instance.to(device=device, dtype=dtype)
instance.eval()
instance.requires_grad_(False)
return instance
def _load_weights(self, model_path: str, embed_key: str, lm_head_key: str):
        # Locate a *.index.json file: sharded checkpoints ship one that maps each
        # tensor name to the shard file containing it.
index_files = glob.glob(os.path.join(model_path, "*.index.json"))
weight_map = {}
if index_files:
# Sharded Checkpoint
with open(index_files[0], "r") as f:
index = json.load(f)
# Find which file contains our keys
weight_map = index.get("weight_map", {})
files_to_load = {}
if embed_key in weight_map:
files_to_load[embed_key] = weight_map[embed_key]
            else:
                # Key not in the index; warn with a sample of available keys so
                # mismatched prefixes are easier to diagnose.
                print(
                    f"Warning: {embed_key} not found in weight_map. "
                    f"Sample of available keys: {list(weight_map.keys())[:5]}..."
                )
            if lm_head_key in weight_map:
                files_to_load[lm_head_key] = weight_map[lm_head_key]
            else:
                # Often missing when the checkpoint ties lm_head to the embeddings
                # (config.tie_word_embeddings).
                print(f"Warning: {lm_head_key} not found in weight_map.")
# Load specific files
for key, filename in files_to_load.items():
file_path = os.path.join(model_path, filename)
self._load_key_from_file(file_path, key)
else:
# Non-sharded Checkpoint (single file)
# Try finding .safetensors or .bin
safetensors = glob.glob(os.path.join(model_path, "*.safetensors"))
bins = glob.glob(os.path.join(model_path, "*.bin"))
target_file = None
if safetensors:
target_file = safetensors[0]
elif bins:
target_file = bins[0]
if target_file:
self._load_key_from_file(target_file, embed_key)
self._load_key_from_file(target_file, lm_head_key)
else:
raise FileNotFoundError(f"No checkpoint file found in {model_path}")
def _load_key_from_file(self, file_path: str, key: str):
tensor = None
if file_path.endswith(".safetensors"):
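            # safe_open memory-maps the file, so only the requested tensor is
            # actually read from disk; the rest of the checkpoint stays untouched.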
with safe_open(file_path, framework="pt") as f:
if key in f.keys():
tensor = f.get_tensor(key)
else:
            # .bin fallback: torch.load reads the entire state dict into memory,
            # which is less efficient than safetensors but still works.
            state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if key in state_dict:
tensor = state_dict[key]
del state_dict # Free immediately
if tensor is not None:
if key.endswith("embed_tokens.weight"):
self.embed_tokens.weight.data.copy_(tensor)
print(f"Loaded embedding weights from {file_path}")
elif key.endswith("lm_head.weight"):
self.lm_head.weight.data.copy_(tensor)
print(f"Loaded lm_head weights from {file_path}")
else:
print(f"Warning: Key {key} not found in {file_path}")