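"""Lazy, shard-aware tensor loading for model checkpoints.

Builds an index of which shard file contains each tensor, then loads
tensors on demand so that only one shard needs to be resident at a time.
"""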
import json
import logging
import os
import os.path
from dataclasses import dataclass
from typing import Dict, List, Optional

import safetensors
import safetensors.torch
import torch
from torch import Tensor

from mergekit.io.loader import TensorLoader


@dataclass
class ShardInfo:
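    """A single checkpoint shard: its filename and the tensor keys it contains."""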
    filename: str
    contained_keys: List[str]


@dataclass
class ShardedTensorIndex:
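    """Index mapping tensor names to the shard files of a checkpoint on disk."""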
    base_path: str
    is_safetensors: bool
    tensor_paths: Dict[str, str]
    shards: List[ShardInfo]

    @classmethod
    def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
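        """Build an index for the model stored at `base_path`.

        Handles both sharded checkpoints (via their `.index.json` weight map)
        and single-file `.safetensors` or `pytorch_model.bin` checkpoints.
        """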
        model_path = None
        # note: candidates are tested as literal paths (plus a ".index.json"
        # suffix); wildcard entries are not expanded by os.path.exists
        for model_file_name in [
            "model.safetensors",
            "pytorch_model.bin",
            "pytorch_model*.safetensors",
        ]:
            candidate_path = os.path.join(base_path, model_file_name)
            if os.path.exists(candidate_path) or os.path.exists(
                candidate_path + ".index.json"
            ):
                model_path = candidate_path
                break

        if not model_path:
            raise RuntimeError(f"Unable to find model files at {base_path}")

        is_safetensors = model_path.endswith(".safetensors")
        tensor_paths = None
        shards = []

        if os.path.exists(model_path + ".index.json"):
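            # sharded checkpoint: the index's weight_map maps each tensor name
            # to the shard file that contains it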
            with open(model_path + ".index.json", "r") as fd:
                weight_map = json.load(fd)["weight_map"]
            tensor_paths = weight_map

            shard_names = list(sorted(set(tensor_paths[e] for e in tensor_paths)))
            for shard_name in shard_names:
                info = ShardInfo(
                    shard_name,
                    [key for key in tensor_paths if tensor_paths[key] == shard_name],
                )
                shards.append(info)

        elif os.path.exists(model_path):
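            # single-file checkpoint: enumerate the tensor keys it contains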
            shard_name = os.path.basename(model_path)

            if model_path.lower().endswith(".safetensors"):
                with safetensors.safe_open(model_path, framework="pt") as st:
                    tensor_paths = {key: shard_name for key in st.keys()}
            else:
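                # bare torch checkpoint: load onto the meta device so the keys
                # can be read without materializing the weights in memory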
                shard = torch.load(model_path, map_location="meta")
                if "state_dict" in shard:
                    shard = shard["state_dict"]

                tensor_paths = {key: shard_name for key in shard}

            shards.append(
                ShardInfo(os.path.basename(model_path), list(tensor_paths.keys()))
            )

        return ShardedTensorIndex(
            base_path=base_path,
            is_safetensors=is_safetensors,
            tensor_paths=tensor_paths,
            shards=shards,
        )


class LazyTensorLoader:
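    """Loads tensors from an indexed checkpoint on demand.

    At most one shard is kept open at a time; requesting a tensor from a
    different shard closes the current one and opens the shard that holds it.
    """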
    index: ShardedTensorIndex
    current_shard: Optional[TensorLoader]
    lazy_unpickle: bool

    def __init__(self, index: ShardedTensorIndex, lazy_unpickle: bool = True):
        self.index = index
        self.current_shard = None
        self.lazy_unpickle = lazy_unpickle

    def get_tensor(
        self,
        key: str,
        device: str = "cpu",
        aliases: Optional[List[str]] = None,
        raise_on_missing: bool = True,
    ) -> Optional[Tensor]:
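        """Return the tensor named `key`, loading its shard if necessary.

        If `key` is not in the index, each name in `aliases` is tried in order.
        Raises KeyError for unknown keys unless `raise_on_missing` is False,
        in which case None is returned.
        """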
        # fall back to an alias when the requested key is not in the index
        if aliases and key not in self.index.tensor_paths:
            for alias in aliases:
                if alias in self.index.tensor_paths:
                    key = alias
                    break

        if self.current_shard is None or key not in self.current_shard.keys():
            if key not in self.index.tensor_paths:
                if raise_on_missing:
                    raise KeyError(key)
                return None

            # the tensor lives in another shard - drop the current one and open it
            self.current_shard = None
            self.current_keys = None

            shard_file = self.index.tensor_paths[key]
            shard_full_path = os.path.join(self.index.base_path, shard_file)
            logging.debug(f"Opening shard {shard_full_path}")
            self.current_shard = TensorLoader.get(
                shard_full_path, use_lazy_unpickle=self.lazy_unpickle, device=device
            )

        return self.current_shard.get_tensor(key).to(device)

    def flush(self):
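        """Release the currently open shard."""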
        self.current_shard = None
        self.current_keys = None

    @classmethod
    def from_disk(
        cls, base_path: str, lazy_unpickle: bool = True
    ) -> "LazyTensorLoader":
        return LazyTensorLoader(ShardedTensorIndex.from_disk(base_path), lazy_unpickle)