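"""Tensor I/O tasks used by the merge graph.

Defines tasks for lazily loading tensors from input models, gathering the
same weight across several models, and streaming merged tensors out to disk.
"""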
import os
from typing import Dict, Optional, Tuple

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, dtype_from_name
from mergekit.graph import Task
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions


class LoaderCache:
    """Singleton cache of lazy tensor loaders, one per referenced model."""

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    lora_cache_dir: Optional[str] = None
    hf_cache_dir: Optional[str] = None
    lazy_unpickle: bool = False
    trust_remote_code: bool = False

    # singleton instance
    _instance: Optional["LoaderCache"] = None

    def __new__(cls) -> "LoaderCache":
        if cls._instance is None:
            cls._instance = super(LoaderCache, cls).__new__(cls)
        return cls._instance

    def get(self, model: ModelReference) -> LazyTensorLoader:
        if model not in self.loaders:
            # Resolve the reference (merging in any LoRA) before creating a loader
            merged = model.merged(
                cache_dir=self.lora_cache_dir, trust_remote_code=self.trust_remote_code
            )
            self.loaders[model] = merged.lazy_loader(
                cache_dir=self.hf_cache_dir, lazy_unpickle=self.lazy_unpickle
            )
        return self.loaders[model]

    def flush_all(self):
        for loader in self.loaders.values():
            loader.flush()

    def setup(self, options: MergeOptions):
        self.lora_cache_dir = options.lora_merge_cache
        self.hf_cache_dir = options.transformers_cache
        self.lazy_unpickle = options.lazy_unpickle
        self.trust_remote_code = options.trust_remote_code


def _normalized_shard_name(path: str) -> str:
    """Normalize a shard filename so .bin and safetensors checkpoints group alike."""
    name, _ext = os.path.splitext(os.path.basename(path))
    return name.lower().replace("pytorch_model", "model")


class LoadTensor(Task[Optional[torch.Tensor]]):
    """Task that loads a single tensor from a model, optionally casting dtype/device."""

    model: ModelReference
    tensor: str
    dtype: Optional[str] = None
    device: Optional[str] = None
    optional: bool = False
    aliases: Optional[Tuple[str, ...]] = None

    def arguments(self) -> Dict[str, Task]:
        return {}

    def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
        # Fall back to known aliases if the canonical name is absent
        all_names = [self.tensor] + list(self.aliases or [])
        for name in all_names:
            if name in loader.index.tensor_paths:
                return name
        return None

    def execute(self) -> Optional[torch.Tensor]:
        loader = LoaderCache().get(self.model)
        name = self._resolve_name(loader)
        if not name:
            if not self.optional:
                raise RuntimeError(
                    f"Tensor {self.tensor} required but not present in model {self.model}"
                )
            return None
        x = loader.get_tensor(name, device=self.device or "cpu")
        if self.dtype:
            x = x.to(dtype=dtype_from_name(self.dtype))
        return x

    def priority(self) -> int:
        return -1000

    def group_label(self) -> Optional[str]:
        # Group loads by the shard they read from, so each shard is touched once
        loader = LoaderCache().get(self.model)
        name = self._resolve_name(loader)
        if name:
            shard_path = loader.index.tensor_paths[name]
            return _normalized_shard_name(shard_path)
        return None


class GatherTensors(Task[Dict[ModelReference, torch.Tensor]]):
    """Task that collects the same weight from several models into one dict."""

    weight_info: ImmutableMap[ModelReference, WeightInfo]
    dtype: Optional[str] = None
    device: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        return {
            f"{str(model)}:{wi.name}": LoadTensor(
                model=model,
                tensor=wi.name,
                dtype=self.dtype,
                device=self.device,
                optional=wi.optional,
                aliases=wi.aliases,
            )
            for (model, wi) in self.weight_info.items()
        }

    def group_label(self) -> Optional[str]:
        return max(t.group_label() or "" for t in self.arguments().values())

    def priority(self) -> int:
        return -10

    def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]:
        key2model = {
            f"{str(model)}:{wi.name}": model for (model, wi) in self.weight_info.items()
        }
        # Models whose tensor was optional and absent are simply omitted
        return {
            key2model[key]: kwargs[key] for key in key2model if kwargs[key] is not None
        }


class TensorWriterTask(Task[TensorWriter]):
    out_path: str
    max_shard_size: int
    safe_serialization: bool = True

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self, **_kwargs) -> TensorWriter:
        return TensorWriter(
            self.out_path,
            max_shard_size=self.max_shard_size,
            safe_serialization=self.safe_serialization,
        )


class SaveTensor(Task[None]):
    """Task that writes a computed tensor to the output checkpoint."""

    tensor_name: str
    tensor_task: Task
    writer_task: TensorWriterTask
    clone: bool
    optional: bool = False

    def arguments(self) -> Dict[str, Task]:
        return {"writer": self.writer_task, "tensor": self.tensor_task}

    def priority(self) -> int:
        # High priority: write tensors as soon as they are ready so they can be freed
        return 1000

    def group_label(self) -> Optional[str]:
        return self.tensor_task.group_label()

    def execute(self, writer: TensorWriter, tensor: Optional[torch.Tensor]) -> None:
        if tensor is None:
            if not self.optional:
                raise RuntimeError(f"No value for required tensor {self.tensor_name}")
            return
        writer.save_tensor(name=self.tensor_name, tensor=tensor, clone=self.clone)


class FinalizeModel(Task[None]):
    """Task that finalizes the output once every SaveTensor task has run."""

    tensor_save_tasks: Tuple[Task, ...]
    writer_task: TensorWriterTask

    def arguments(self) -> Dict[str, Task]:
        # The save tasks are dependencies only; their (None) results are unused
        return {
            "writer": self.writer_task,
            **{f"_unused_{idx}": t for idx, t in enumerate(self.tensor_save_tasks)},
        }

    def execute(self, writer: TensorWriter, **kwargs) -> None:
        writer.finalize()
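

# Illustrative sketch (not part of this module): how these tasks might compose
# into a minimal copy pipeline for a single weight. The model reference and
# tensor name below are assumptions for demonstration, not values used here.
#
#     model = ModelReference.parse("some-org/some-model")  # hypothetical model
#     writer_task = TensorWriterTask(out_path="./out", max_shard_size=1 << 30)
#     load = LoadTensor(model=model, tensor="model.embed_tokens.weight")
#     save = SaveTensor(
#         tensor_name="model.embed_tokens.weight",
#         tensor_task=load,
#         writer_task=writer_task,
#         clone=False,
#     )
#     finalize = FinalizeModel(tensor_save_tasks=(save,), writer_task=writer_task)
#     # An executor from mergekit.graph would then schedule and run `finalize`,
#     # using priority() and group_label() to order loads and saves shard by shard.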