# lazy_tensor_loader.py — from the Merge-aml repository (upload by bruhzair, revision 1982d49)
# Copyright (C) 2025 Arcee AI
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
import glob
import json
import logging
import os
import os.path
from dataclasses import dataclass
from typing import Dict, List, Optional

import safetensors
import safetensors.torch
import torch
from torch import Tensor

from mergekit.io.loader import TensorLoader
@dataclass
class ShardInfo:
    """A single checkpoint shard file and the tensors it contains."""

    # Shard file name, relative to the checkpoint directory.
    filename: str
    # Names of all tensors stored in this shard.
    contained_keys: List[str]


@dataclass
class ShardedTensorIndex:
    """Index mapping tensor names to the shard files that contain them."""

    # Directory containing the checkpoint files.
    base_path: str
    # True if the checkpoint is stored in safetensors format.
    is_safetensors: bool
    # Maps each tensor name to the (relative) file name of its shard.
    tensor_paths: Dict[str, str]
    # One ShardInfo per shard file.
    shards: List["ShardInfo"]

    @classmethod
    def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
        """Build an index from the checkpoint files under ``base_path``.

        Handles single-file checkpoints (``model.safetensors``,
        ``pytorch_model.bin``), sharded checkpoints described by a
        ``*.index.json`` file, and nonstandard safetensors file names
        matching ``pytorch_model*.safetensors``.

        Raises:
            RuntimeError: If no recognized model files are found.
        """
        model_path = None
        for model_file_name in ["model.safetensors", "pytorch_model.bin"]:
            candidate_path = os.path.join(base_path, model_file_name)
            # A bare file or an accompanying .index.json both count as a hit.
            if os.path.exists(candidate_path) or os.path.exists(
                candidate_path + ".index.json"
            ):
                model_path = candidate_path
                break
        if model_path is None:
            # Fall back to globbing for nonstandard safetensors names.
            # os.path.exists does not expand wildcards, so the pattern must
            # be resolved explicitly with glob.
            matches = sorted(
                glob.glob(os.path.join(base_path, "pytorch_model*.safetensors"))
            )
            if matches:
                model_path = matches[0]

        if not model_path:
            raise RuntimeError(f"Unable to find model files at {base_path}")

        is_safetensors = model_path.endswith(".safetensors")
        tensor_paths = None
        shards = []

        if os.path.exists(model_path + ".index.json"):
            # sharded model - parse index
            with open(model_path + ".index.json", "r") as fd:
                weight_map = json.load(fd)["weight_map"]
            tensor_paths = weight_map

            # Group tensor names by the shard file containing them.
            shard_names = list(sorted(set(tensor_paths[e] for e in tensor_paths)))
            for shard_name in shard_names:
                info = ShardInfo(
                    shard_name,
                    [key for key in tensor_paths if tensor_paths[key] == shard_name],
                )
                shards.append(info)
        elif os.path.exists(model_path):
            shard_name = os.path.basename(model_path)
            # get list of tensors contained in single-file checkpoint
            if model_path.lower().endswith(".safetensors"):
                with safetensors.safe_open(model_path, framework="pt") as st:
                    tensor_paths = {key: shard_name for key in st.keys()}
            else:
                # this is ugly but not much else can be done - load the
                # pickle onto the meta device just to enumerate its keys
                shard = torch.load(model_path, map_location="meta")
                if "state_dict" in shard:
                    shard = shard["state_dict"]
                tensor_paths = {key: shard_name for key in shard}
            shards.append(
                ShardInfo(os.path.basename(model_path), list(tensor_paths.keys()))
            )

        return cls(
            base_path=base_path,
            is_safetensors=is_safetensors,
            tensor_paths=tensor_paths,
            shards=shards,
        )
class LazyTensorLoader:
    """On-demand tensor loader for (possibly sharded) checkpoints.

    Keeps at most one shard open at a time: requesting a tensor that lives
    in a different shard releases the current one and opens the new shard.
    """

    def __init__(self, index: "ShardedTensorIndex", lazy_unpickle: bool = True):
        """Create a loader over a pre-built shard index.

        Args:
            index: Index mapping tensor names to shard files.
            lazy_unpickle: Forwarded to TensorLoader.get as use_lazy_unpickle.
        """
        self.index = index
        self.current_shard: Optional[TensorLoader] = None
        # Initialized here so readers never hit AttributeError before the
        # first shard is opened; cleared alongside current_shard.
        self.current_keys = None
        self.lazy_unpickle = lazy_unpickle

    def get_tensor(
        self,
        key: str,
        device: str = "cpu",
        aliases: Optional[List[str]] = None,
        raise_on_missing: bool = True,
    ) -> Optional[Tensor]:
        """Return the tensor named ``key``, loading its shard if necessary.

        Args:
            key: Primary tensor name to look up.
            device: Device to place the returned tensor on.
            aliases: Alternate names tried (in order) when ``key`` is absent
                from the index; the first one present is used instead.
            raise_on_missing: If True, raise KeyError for unknown tensors;
                otherwise return None.
        """
        if aliases and key not in self.index.tensor_paths:
            for alias in aliases:
                if alias in self.index.tensor_paths:
                    key = alias
                    break

        if self.current_shard is None or key not in self.current_shard.keys():
            if key not in self.index.tensor_paths:
                if raise_on_missing:
                    raise KeyError(key)
                return None

            # Drop the old shard before opening the new one to keep peak
            # memory usage down.
            self.current_shard = None
            self.current_keys = None

            shard_file = self.index.tensor_paths[key]
            shard_full_path = os.path.join(self.index.base_path, shard_file)
            logging.debug("Opening shard %s", shard_full_path)
            self.current_shard = TensorLoader.get(
                shard_full_path, use_lazy_unpickle=self.lazy_unpickle, device=device
            )

        return self.current_shard.get_tensor(key).to(device)

    def flush(self):
        """Release the currently open shard (if any)."""
        self.current_shard = None
        self.current_keys = None

    @classmethod
    def from_disk(
        cls, base_path: str, lazy_unpickle: bool = True
    ) -> "LazyTensorLoader":
        """Build a loader by indexing the checkpoint files under ``base_path``."""
        return cls(ShardedTensorIndex.from_disk(base_path), lazy_unpickle)