|
"""Utility functions for training and inference.""" |
|
|
|
import functools |
|
import pickle |
|
import warnings |
|
from io import BytesIO |
|
from pathlib import Path |
|
from contextlib import contextmanager |
|
|
|
import torch |
|
import torch.utils._device |
|
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy |
|
from torch.distributed.fsdp import FullStateDictConfig |
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
from torch.distributed.fsdp import StateDictType |
|
from torch.serialization import normalize_storage_type |
|
|
|
# Maps the embedding width found in a checkpoint to the LLaMA model name.
llama_model_sizes = {
    4096: "7B",
    5120: "13B",
    6656: "30B",
    8192: "65B",
}


def llama_model_lookup(checkpoint: dict) -> str:
    """Return the LLaMA model name for a checkpoint.

    The model is identified by the width (second dimension) of the
    ``transformer.wte.weight`` entry, which is looked up in
    ``llama_model_sizes``.

    Args:
        checkpoint: mapping that contains a ``transformer.wte.weight`` entry
            exposing a ``shape`` attribute.

    Returns:
        The model name, e.g. ``"7B"``.

    Raises:
        KeyError: if the entry is missing from ``checkpoint`` or its width is
            not listed in ``llama_model_sizes``.
    """
    embedding_size = checkpoint["transformer.wte.weight"].shape[1]
    try:
        return llama_model_sizes[embedding_size]
    except KeyError:
        known_widths = ", ".join(str(width) for width in llama_model_sizes)
        # Same exception type as a plain dict lookup, but with a message that
        # tells the user what went wrong.
        raise KeyError(
            f"unknown embedding size {embedding_size}; known sizes: {known_widths}"
        ) from None
|
|
|
|
|
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to a multiple of ``k``.

    ``n`` is returned unchanged when it is already divisible by ``k``.
    For positive ``k`` the result is the smallest multiple of ``k`` that is
    greater than or equal to ``n``.
    """
    remainder = n % k
    return n + k - remainder if remainder else n
|
|
|
|
|
def save_model_checkpoint(fabric, model, file_path):
    """Retrieve the model's state dict and write it to ``file_path``.

    The way the weights are collected depends on ``fabric.strategy``:

    * DeepSpeed: the checkpoint is saved through ``fabric.save`` and rank 0
      then converts it into an fp32 state dict stored next to it with a
      ``.pth`` suffix.
    * FSDP: a full state dict is requested (``rank0_only=True``, offloaded to
      CPU when more than one process participates) and saved by rank 0.
    * Anything else: ``model.state_dict()`` is saved by rank 0.

    This will be upstreamed to Fabric soon.
    """
    target = Path(file_path)
    strategy = fabric.strategy

    if isinstance(strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

        fabric.save(target, {"model": model})
        # Wait for every rank before rank 0 touches the saved checkpoint.
        fabric.barrier()
        if fabric.global_rank == 0:
            convert_zero_checkpoint_to_fp32_state_dict(target, target.with_suffix(".pth"))
        return

    if not isinstance(strategy, FSDPStrategy):
        weights = model.state_dict()
    else:
        full_state_cfg = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_cfg):
            weights = model._forward_module.state_dict()

    if fabric.global_rank == 0:
        torch.save(weights, target)
    fabric.barrier()
|
|
|
|
|
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Create tensors on a given device/dtype and skip weight initialization.

    While the mode is active:

    * calls into ``torch.nn.init`` return their tensor argument untouched, so
      parameters keep whatever (uninitialized) memory they were created with;
    * tensor constructors that do not specify ``device`` / ``dtype`` get the
      values configured here;
    * if a quantization mode is selected, ``torch.nn.Linear`` is replaced by
      the matching quantized linear class for the duration of the context.

    Example::

        with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
            model = LLaMA.from_name('7B')
        model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))
    """

    def __init__(self, device=None, dtype=None, quantization_mode=None):
        """Configure the mode.

        Args:
            device: `torch.device` (or anything `torch.device` accepts, such as
                the string ``"cuda"``) to create tensors on; ``None`` leaves
                the device untouched.
            dtype: `torch.dtype` to create tensors with; ``None`` leaves the
                dtype untouched.
            quantization_mode: optional string, default ``None``.
                Available modes: ``llm.int8`` bitsnbytes LLM.int8 quantization
                (only on GPU); ``gptq.int4``, ``gptq.int8``: GPTQ
                pre-quantized models.

        Raises:
            ValueError: if ``llm.int8`` is requested for a non-CUDA device.
            RuntimeError: if ``quantization_mode`` is not a known mode.
        """
        super().__init__()
        self.quantization_mode = quantization_mode
        self.quantized_linear_cls = None
        if self.quantization_mode == 'llm.int8':
            # Normalize through torch.device so strings such as "cuda" work;
            # a missing device cannot be confirmed to be a GPU.
            if device is None or torch.device(device).type != "cuda":
                raise ValueError("Quantization is only supported on the GPU.")
            from lit_llama.quantization import Linear8bitLt
            self.quantized_linear_cls = Linear8bitLt
        elif self.quantization_mode == 'gptq.int4':
            from lit_llama.quantization import ColBlockQuantizedLinear
            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
        elif self.quantization_mode == 'gptq.int8':
            from lit_llama.quantization import ColBlockQuantizedLinear
            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
        elif self.quantization_mode is not None:
            raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        """Swap in the quantized linear class (if any) and activate the mode."""
        if self.quantized_linear_cls is not None:
            # Remember the current class so __exit__ can put it back.
            self.torch_linear_cls = torch.nn.Linear
            torch.nn.Linear = self.quantized_linear_cls
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore ``torch.nn.Linear`` (if it was replaced) and leave the mode."""
        if self.quantized_linear_cls is not None:
            torch.nn.Linear = self.torch_linear_cls
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Intercept torch calls made while the mode is active.

        ``torch.nn.init`` functions are skipped and simply hand back the
        tensor they were asked to initialize. Tensor constructors receive the
        configured ``device`` / ``dtype`` unless the caller passed its own.
        """
        # Work on a copy so the caller's dict is never modified.
        kwargs = dict(kwargs) if kwargs else {}
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]

        wants_device = self.device is not None and kwargs.get("device") is None
        wants_dtype = self.dtype is not None and kwargs.get("dtype") is None
        if (wants_device or wants_dtype) and func in torch.utils._device._device_constructors():
            if wants_device:
                kwargs["device"] = self.device
            if wants_dtype:
                kwargs["dtype"] = self.dtype
        return func(*args, **kwargs)
|
|
|
|
|
@contextmanager
def quantization(mode: str = None):
    """Temporarily replace ``torch.nn.Linear`` with a quantized linear class.

    Args:
        mode: ``None`` (no replacement), ``'llm.int8'``, ``'gptq.int4'`` or
            ``'gptq.int8'``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.

    The original ``torch.nn.Linear`` is restored when the context exits, also
    when the body of the ``with`` statement raises.
    """
    quantized_linear_cls = None
    if mode == 'llm.int8':
        from .quantization import Linear8bitLt
        quantized_linear_cls = Linear8bitLt
    elif mode == 'gptq.int4':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
    elif mode == 'gptq.int8':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
    elif mode is not None:
        raise ValueError(f"Unknown quantization mode: {mode}")

    if mode is None:
        # Nothing to patch.
        yield
        return

    torch_linear_cls = torch.nn.Linear
    torch.nn.Linear = quantized_linear_cls
    try:
        yield
    finally:
        # Always undo the global patch, even if the with-body failed.
        torch.nn.Linear = torch_linear_cls
|
|
|
|
|
|
|
|
|
|
|
class NotYetLoadedTensor:
    """Stand-in for a checkpoint tensor whose data has not been read yet.

    Instances are produced while unpickling through the ``rebuild_*`` class
    methods below (``LazyLoadingUnpickler.find_class`` substitutes them for
    the corresponding torch rebuild functions). The real data is only read
    from the archive when ``_load_tensor`` is called, e.g. through
    ``__torch_function__`` or the ``contiguous`` attribute.
    """

    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
        # Tensor rebuilt from the placeholder storage; used to answer metadata
        # queries (dtype, shape, ...) without reading any data.
        self.metatensor = metatensor
        # Object giving access to the archive; `_load_tensor` reads
        # `archiveinfo.zipfile_context.zf` from it.
        self.archiveinfo = archiveinfo
        # 5-tuple `(name, storage_cls, fn, device, size)` describing the storage.
        self.storageinfo = storageinfo
        # Arguments (everything but the storage) for `_rebuild_tensor_v2`.
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
        """Lazy counterpart of ``torch._tensor._rebuild_from_type_v2``.

        ``func(*args)`` is evaluated first. If it yields a
        ``NotYetLoadedTensor``, the type/state rebuild is deferred by wrapping
        that object's loader; otherwise torch's function is called directly.
        """
        ret = func(*args)
        if isinstance(ret, NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor():
                t = old_lt()
                # The tensor is already loaded, so hand it over via a
                # zero-argument function and an empty args tuple.
                return torch._tensor._rebuild_from_type_v2(
                    lambda: t, new_type, (), state
                )

            # Instance attribute shadows the method, chaining the extra step.
            ret._load_tensor = _load_tensor
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls, data, requires_grad, backward_hooks, *, archiveinfo=None
    ):
        """Lazy counterpart of ``torch._utils._rebuild_parameter``.

        If ``data`` is a ``NotYetLoadedTensor`` the parameter construction is
        deferred until the data is loaded; otherwise torch's function is
        called directly.
        """
        if isinstance(data, NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor():
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            # Instance attribute shadows the method, chaining the extra step.
            data._load_tensor = _load_tensor
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage,
        storage_offset,
        size,
        stride,
        requires_grad,
        backward_hooks,
        metadata=None,
        *,
        archiveinfo=None,
    ):
        """Lazy counterpart of ``torch._utils._rebuild_tensor_v2``.

        Builds the metadata tensor from ``storage`` and returns a
        ``NotYetLoadedTensor`` that remembers how to rebuild the real tensor.
        ``storage`` must carry an ``archiveinfo`` attribute (set by
        ``LazyLoadingUnpickler.persistent_load``).
        """
        rebuild_args = (
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        metatensor = torch._utils._rebuild_tensor_v2(
            storage,
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        storageinfo = storage.archiveinfo
        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self):
        """Read the storage record from the archive and rebuild the tensor."""
        name, storage_cls, fn, device, size = self.storageinfo
        dtype = self.metatensor.dtype

        # Record "data/<fn>" holds `size` elements of `dtype`; the reader is
        # asked for the corresponding number of bytes.
        uts = (
            self.archiveinfo.zipfile_context.zf.get_storage_from_record(
                f"data/{fn}",
                size * torch._utils._element_size(dtype),
                torch.UntypedStorage,
            )
            ._typed_storage()
            ._untyped_storage
        )
        # Any warning raised while constructing the TypedStorage is silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(
                wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
            )
        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Load placeholder arguments, then run ``func`` on real tensors.

        Only top-level positional arguments are inspected; keyword arguments
        are passed through unchanged.
        """
        if kwargs is None:
            kwargs = {}
        loaded_args = [
            (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
        ]
        res = func(*loaded_args, **kwargs)

        return res

    def __getattr__(self, name):
        """Serve a small set of tensor attributes; reject everything else."""
        # Metadata attributes: answered from the metadata tensor, no loading.
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
        # `size` is also taken from the metadata tensor.
        if name in {"size"}:
            return getattr(self.metatensor, name)

        # `contiguous` needs the real data, so the tensor is loaded first.
        if name in {"contiguous"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"
|
|
|
|
|
class LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler that defers reading tensor data.

    Torch's tensor rebuild functions are swapped for the matching
    ``NotYetLoadedTensor`` class methods, and persistent ids (storages) are
    turned into placeholder storages on the ``meta`` device.
    """

    # (module, name) of the torch rebuild function -> NotYetLoadedTensor method.
    _LAZY_REBUILD_HOOKS = {
        ("torch._utils", "_rebuild_tensor_v2"): "rebuild_tensor_v2",
        ("torch._tensor", "_rebuild_from_type_v2"): "rebuild_from_type_v2",
        ("torch._utils", "_rebuild_parameter"): "rebuild_parameter",
    }

    def __init__(self, file, zipfile_context):
        super().__init__(file)
        self.zipfile_context = zipfile_context

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting the lazy rebuild hooks."""
        # The regular lookup always runs first, exactly like a plain Unpickler.
        resolved = super().find_class(module, name)
        hook_name = self._LAZY_REBUILD_HOOKS.get((module, name))
        if hook_name is None:
            return resolved
        # The unpickler itself is handed to the hook as `archiveinfo`.
        return functools.partial(
            getattr(NotYetLoadedTensor, hook_name), archiveinfo=self
        )

    def persistent_load(self, pid):
        """Create a ``meta`` storage of the right dtype tagged with ``pid``."""
        _tag, storage_cls, _key, _location, _numel = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            placeholder = torch.storage.TypedStorage(
                dtype=storage_cls().dtype, device="meta"
            )
        # Keep the persistent id so the data can be located later.
        placeholder.archiveinfo = pid
        return placeholder
|
|
|
|
|
class lazy_load:
    """Context manager giving lazy access to a checkpoint file.

    On construction the file is opened with ``torch._C.PyTorchFileReader`` and
    its ``data.pkl`` record is unpickled with ``LazyLoadingUnpickler``. Entering
    the context returns the unpickled object; leaving it drops the reader.
    """

    def __init__(self, fn):
        self.zf = torch._C.PyTorchFileReader(str(fn))
        pickled = self.zf.get_record("data.pkl")
        with BytesIO(pickled) as stream:
            # The unpickler keeps a reference to `self` as its zipfile context.
            self.sd = LazyLoadingUnpickler(stream, self).load()

    def __enter__(self):
        return self.sd

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Rebinding drops this object's reference to the file reader.
        self.zf = None
|
|
|
|
|
class SavingProxyForStorage:
    """Placeholder for a storage that has already been written by ``saver``.

    Constructing the proxy immediately writes the storage through
    ``saver._write_storage_and_return_key`` and records the 5-tuple
    ``("storage", storage_type, key, location, numel)`` in ``storage_info``.
    ``IncrementalPyTorchPickler.persistent_id`` returns that tuple when it
    meets the proxy, so the proxy itself is never pickled in-band.
    """

    def __init__(self, obj, saver, protocol_version=5):
        """Write ``obj`` via ``saver`` and remember its persistent id.

        Args:
            obj: a ``torch.storage.TypedStorage`` or any object accepted by
                ``torch.is_storage``.
            saver: object providing ``_write_storage_and_return_key(storage)``.
            protocol_version: pickle protocol version, stored on the proxy.

        Raises:
            TypeError: if ``obj`` is not a storage.
        """
        self.protocol_version = protocol_version
        self.saver = saver
        is_typed = isinstance(obj, torch.storage.TypedStorage)
        if not (is_typed or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        if is_typed:
            # Typed storages wrap an untyped one; the size is counted in
            # elements as reported by `_size()`.
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            # Untyped storages are sized in bytes.
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        """Refuse in-band pickling.

        Raises:
            AssertionError: always. Raised explicitly (rather than with an
                ``assert`` statement) so the guard survives ``python -O``.
        """
        raise AssertionError("this should be handled with out of band")
|
|
|
|
|
class SavingProxyForTensor:
    """Picklable stand-in for a tensor whose storage is written right away.

    The tensor's own reduction is captured at construction time, with its
    storage argument replaced by a ``SavingProxyForStorage`` (which writes the
    storage through ``saver``). Pickling the proxy later replays that
    reduction.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        rebuild_fn, rebuild_args = tensor.__reduce_ex__(protocol_version)
        self.reduce_ret_fn = rebuild_fn
        # The first reduce argument is expected to be the tensor's storage.
        storage, *remaining_args = rebuild_args
        assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
        proxied_storage = SavingProxyForStorage(
            storage, saver, protocol_version=protocol_version
        )
        self.reduce_args = (proxied_storage, *remaining_args)

    def __reduce_ex__(self, protocol_version):
        """Return the captured reduction; the protocol version must match."""
        if protocol_version == self.protocol_version:
            return self.reduce_ret_fn, self.reduce_args
        raise RuntimeError(
            f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
        )
|
|
|
|
|
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that stores storages out of band through ``saver``.

    Storages are replaced by persistent ids of the form
    ``("storage", storage_type, key, location, numel)``. Each distinct storage
    is written once; its key is cached in ``id_map``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # data_ptr -> dtype under which that memory was first seen.
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> key returned by the saver.
        self.id_map = {}

    def persistent_id(self, obj):
        """Return the persistent id for storages, ``None`` for anything else."""
        # Storages written early already carry their id.
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        is_typed = isinstance(obj, torch.storage.TypedStorage)
        if not (is_typed or torch.is_storage(obj)):
            return None

        if is_typed:
            raw_storage = obj._untyped_storage
            dtype = obj.dtype
            storage_type = getattr(torch, obj._pickle_storage_type())
            numel = obj._size()
        else:
            # Untyped storages are treated as bytes.
            raw_storage = obj
            dtype = torch.uint8
            storage_type = normalize_storage_type(type(obj))
            numel = raw_storage.nbytes()

        # The same memory must not be saved under two different dtypes.
        ptr = raw_storage.data_ptr()
        if ptr != 0:
            first_seen_dtype = self.storage_dtypes.setdefault(ptr, dtype)
            if dtype != first_seen_dtype:
                raise RuntimeError(
                    "Cannot save multiple tensors or storages that "
                    "view the same data as different types"
                )

        cdata = raw_storage._cdata
        key = self.id_map.get(cdata)
        if key is None:
            key = self.saver._write_storage_and_return_key(raw_storage)
            self.id_map[cdata] = key
        location = torch.serialization.location_tag(raw_storage)

        return ("storage", storage_type, key, location, numel)
|
|
|
|
|
class incremental_save:
    """Context manager that writes a checkpoint archive piece by piece.

    Tensor storages can be written ahead of time with ``store_early``; the
    object structure is pickled once with ``save``. Leaving the context
    finalizes the archive.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        # Key for the next "data/<key>" record.
        self.next_key = 0

    def __enter__(self):
        return self

    def _require_unsaved(self):
        """Raise if ``save`` has already been called."""
        if self.has_saved:
            raise RuntimeError("have already saved")

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a proxy to pickle later."""
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"can only store tensors early, not {type(tensor)}")
        return SavingProxyForTensor(tensor, self)

    def save(self, obj):
        """Pickle ``obj`` into the ``data.pkl`` record; allowed only once."""
        self._require_unsaved()

        buffer = BytesIO()
        IncrementalPyTorchPickler(self, buffer, protocol=5).dump(obj)
        payload = buffer.getvalue()
        self.zipfile.write_record("data.pkl", payload, len(payload))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        """Write ``storage`` as record ``data/<key>`` and return the key."""
        self._require_unsaved()
        key = self.next_key
        self.next_key = key + 1
        # Records are written from CPU memory.
        if storage.device.type != "cpu":
            storage = storage.cpu()
        byte_count = storage.nbytes()
        self.zipfile.write_record(f"data/{key}", storage.data_ptr(), byte_count)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
|
|