# TAPA / lit_llama/utils.py
# Last commit c8ac827 by xuxw98: "Update lit_llama/utils.py" (file size 18.1 kB)
"""Utility functions for training and inference."""
import functools
import pickle
import warnings
from io import BytesIO
from pathlib import Path
from contextlib import contextmanager
import torch
import torch.utils._device
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.serialization import normalize_storage_type
llama_model_sizes = {
    4096: "7B",  # 7B n_embd=4096
    5120: "13B",  # 13B n_embd=5120
    6656: "30B",  # 30B n_embd=6656
    8192: "65B",  # 65B n_embd=8192
}


def llama_model_lookup(checkpoint: dict) -> str:
    """Returns the LLaMA model name from the checkpoint.

    Reads the embedding width (second dimension) of the token-embedding
    matrix ``transformer.wte.weight``. Each supported model size has a
    distinct width, listed in ``llama_model_sizes``.

    Args:
        checkpoint: state dict that contains ``transformer.wte.weight``.

    Returns:
        The model name, e.g. ``"7B"``.

    Raises:
        KeyError: if ``transformer.wte.weight`` is missing, or if its embedding
            width is not one of the known sizes.
    """
    embedding_size = checkpoint["transformer.wte.weight"].shape[1]
    try:
        return llama_model_sizes[embedding_size]
    except KeyError:
        known = ", ".join(f"{size} ({name})" for size, name in llama_model_sizes.items())
        raise KeyError(
            f"Unrecognized LLaMA embedding size {embedding_size}; expected one of: {known}"
        ) from None
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the next multiple of ``k`` (for positive ``k``).

    ``n`` is returned unchanged when it is already a multiple of ``k``.

    >>> find_multiple(10, 4)
    12
    >>> find_multiple(12, 4)
    12
    """
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
def save_model_checkpoint(fabric, model, file_path):
    """Save the model's weights to ``file_path`` according to the Fabric strategy.

    Handles the boilerplate of retrieving and persisting a state_dict; this is
    meant to be upstreamed to Fabric eventually.

    * DeepSpeed: ``fabric.save`` writes the DeepSpeed checkpoint. After a
      barrier, rank 0 converts it to a consolidated fp32 state dict stored
      alongside it with a ``.pth`` suffix.
    * FSDP: gathers a full state dict (offloaded to CPU when ``world_size > 1``,
      kept only on rank 0) and rank 0 writes it with ``torch.save``.
    * Anything else: rank 0 writes ``model.state_dict()`` with ``torch.save``.

    The non-DeepSpeed paths end with ``fabric.barrier()``.
    """
    path = Path(file_path)
    strategy = fabric.strategy

    if isinstance(strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

        fabric.save(path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            # Consolidated checkpoint: same name as the DeepSpeed one, ".pth" suffix.
            convert_zero_checkpoint_to_fp32_state_dict(path, path.with_suffix(".pth"))
        return

    if isinstance(strategy, FSDPStrategy):
        policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, policy):
            weights = model._forward_module.state_dict()
    else:
        weights = model.state_dict()

    if fabric.global_rank == 0:
        torch.save(weights, path)
    fabric.barrier()
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None, dtype=None, quantization_mode=None):
        """
        Create tensors with given device and dtype and don't run initialization
        (but instead use "empty tensors", i.e. uninitialized memory).

        device: `torch.device` (or anything `torch.device()` accepts, e.g. "cuda")
            used for tensor constructors that don't pass `device` explicitly
        dtype: `torch.dtype` used for tensor constructors that don't pass `dtype` explicitly
        quantization_mode: optional string, quantization mode to work with, default `None`.
             Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
                              `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models

        While the context is active, `torch.nn.Linear` is replaced by the quantized
        linear class for the selected mode (if any), and functions from
        `torch.nn.init` return their tensor argument without initializing it.

        Raises:
            ValueError: `llm.int8` was requested without a CUDA device.
            RuntimeError: `quantization_mode` is not one of the modes above.

        Example::
            with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
               model = LLaMA.from_name('7B')
            model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
        self.quantization_mode = quantization_mode
        self.quantized_linear_cls = None
        if self.quantization_mode == 'llm.int8':
            # Normalise first: `device` may be a string like "cuda:0" (no `.type`
            # attribute) or None, both of which previously crashed here.
            if device is None or torch.device(device).type != "cuda":
                raise ValueError("Quantization is only supported on the GPU.")
            from .quantization import Linear8bitLt

            self.quantized_linear_cls = Linear8bitLt
        elif self.quantization_mode == 'gptq.int4':
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
        elif self.quantization_mode == 'gptq.int8':
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
        elif self.quantization_mode is not None:
            raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        if self.quantized_linear_cls is not None:
            # Globally swap nn.Linear so modules built inside the context are quantized.
            self.torch_linear_cls = torch.nn.Linear
            torch.nn.Linear = self.quantized_linear_cls
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.quantized_linear_cls is not None:
            torch.nn.Linear = self.torch_linear_cls
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Work on a copy so the caller's kwargs dict is never mutated.
        kwargs = dict(kwargs) if kwargs else {}
        if getattr(func, "__module__", None) == "torch.nn.init":
            # Skip weight initialisation: hand back the target tensor untouched.
            return kwargs["tensor"] if "tensor" in kwargs else args[0]
        if func in torch.utils._device._device_constructors():
            # Fill in defaults only where the caller did not choose explicitly.
            if self.device is not None and kwargs.get("device") is None:
                kwargs["device"] = self.device
            if self.dtype is not None and kwargs.get("dtype") is None:
                kwargs["dtype"] = self.dtype
        return func(*args, **kwargs)
@contextmanager
def quantization(mode: str = None):
    """Temporarily replace ``torch.nn.Linear`` with a quantized linear class.

    Code that looks up ``torch.nn.Linear`` while the ``with`` block is active
    gets the quantized class for ``mode``. The original class is restored on
    exit, including when the block raises, so a failure can't leave the global
    patch in place.

    Args:
        mode: ``None`` (no-op), ``"llm.int8"``, ``"gptq.int4"`` or ``"gptq.int8"``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    quantized_linear_cls = None
    if mode == 'llm.int8':
        from .quantization import Linear8bitLt

        quantized_linear_cls = Linear8bitLt
    elif mode == 'gptq.int4':
        from .quantization import ColBlockQuantizedLinear

        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
    elif mode == 'gptq.int8':
        from .quantization import ColBlockQuantizedLinear

        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
    elif mode is not None:
        raise ValueError(f"Unknown quantization mode: {mode}")

    if mode is None:
        yield
        return

    torch_linear_cls = torch.nn.Linear
    torch.nn.Linear = quantized_linear_cls
    try:
        yield
    finally:
        # Always undo the global monkey-patch, even if the body raised.
        torch.nn.Linear = torch_linear_cls
# this is taken from torchhacks https://github.com/lernapparat/torchhacks
class NotYetLoadedTensor:
    """Placeholder for a checkpoint tensor whose data has not been read yet.

    Created by ``LazyLoadingUnpickler`` in place of real tensors. Metadata
    queries (``shape``, ``dtype``, ``size()``, ...) are answered from
    ``metatensor`` without touching the data. The data is read from the
    archive only when the placeholder is passed to a torch function or when
    ``contiguous`` is requested.

    Attributes:
        metatensor: tensor rebuilt on the storage handed to ``rebuild_tensor_v2``
            (a ``meta``-device storage when produced by ``LazyLoadingUnpickler``).
            It supplies shape, dtype and other metadata.
        archiveinfo: the unpickler that created this placeholder; its
            ``zipfile_context.zf`` is used to read the tensor data.
        storageinfo: persistent-id tuple ``(tag, storage class, record key,
            location, numel)`` of the backing storage.
        rebuild_args: arguments for ``torch._utils._rebuild_tensor_v2`` that
            follow the storage argument.
    """

    # Attributes answered from `metatensor` without reading any data.
    _METADATA_ATTRS = frozenset(
        {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "shape",
            "size",
            "volatile",
        }
    )

    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
        """Lazy stand-in for ``torch._tensor._rebuild_from_type_v2``.

        If ``func(*args)`` produced a placeholder, wrap its loader so that
        ``new_type`` and ``state`` are applied once the data is materialised.
        """
        ret = func(*args)
        if isinstance(ret, NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor():
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            ret._load_tensor = _load_tensor
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None):
        """Lazy stand-in for ``torch._utils._rebuild_parameter``.

        A placeholder ``data`` has its loader wrapped so that it yields a
        parameter once materialised.
        """
        if isinstance(data, NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor():
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage,
        storage_offset,
        size,
        stride,
        requires_grad,
        backward_hooks,
        metadata=None,
        *,
        archiveinfo=None,
    ):
        """Lazy stand-in for ``torch._utils._rebuild_tensor_v2``.

        Builds the metadata tensor on the given (placeholder) storage and
        remembers everything needed to rebuild the real tensor later.
        """
        rebuild_args = (
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        metatensor = torch._utils._rebuild_tensor_v2(
            storage,
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        storageinfo = storage.archiveinfo
        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self):
        """Read the backing storage from the archive and rebuild the real tensor."""
        _name, _storage_cls, fn, _device, size = self.storageinfo
        dtype = self.metatensor.dtype
        # `size` counts elements; the record is read as raw bytes.
        uts = (
            self.archiveinfo.zipfile_context.zf.get_storage_from_record(
                f"data/{fn}",
                size * torch._utils._element_size(dtype),
                torch.UntypedStorage,
            )
            ._typed_storage()
            ._untyped_storage
        )
        with warnings.catch_warnings():
            # TypedStorage construction emits deprecation warnings.
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @staticmethod
    def _materialize(obj):
        """Load placeholders in ``obj``, recursing into plain lists and tuples."""
        if isinstance(obj, NotYetLoadedTensor):
            return obj._load_tensor()
        if type(obj) in (list, tuple):
            return type(obj)(NotYetLoadedTensor._materialize(item) for item in obj)
        return obj

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Materialise every placeholder in the call before invoking the real
        # function. This covers positional and keyword arguments, and tensor
        # lists such as the one passed to torch.cat. A placeholder left in the
        # call would be dispatched straight back to this handler.
        loaded_args = [cls._materialize(a) for a in args]
        loaded_kwargs = {k: cls._materialize(v) for k, v in (kwargs or {}).items()}
        # gc.collect would be costly here, maybe do it optionally
        return func(*loaded_args, **loaded_kwargs)

    def __getattr__(self, name):
        # Only called for attributes not found normally.
        # TODO: device, is_*, mH, mT, H, T, data, imag, real, name
        if name in self._METADATA_ATTRS:
            return getattr(self.metatensor, name)
        # materializing with contiguous is needed for quantization
        if name == "contiguous":
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"
class LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler that rebuilds checkpoint tensors as ``NotYetLoadedTensor`` placeholders.

    Storages come back as empty ``meta``-device storages tagged with their
    persistent id. torch's tensor/parameter rebuild functions are replaced by
    the lazy ``NotYetLoadedTensor`` classmethods, so no tensor data is read
    while unpickling.

    Args:
        file: binary file-like object holding the pickle stream.
        zipfile_context: object whose ``zf`` attribute is the archive reader
            later used to load tensor data (``lazy_load`` passes itself).

    Security note: ``find_class`` still resolves arbitrary globals, so this is
    exactly as unsafe as ``pickle.load`` on untrusted files.
    """

    def __init__(self, file, zipfile_context):
        super().__init__(file)
        self.zipfile_context = zipfile_context

    def find_class(self, module, name):
        # Resolve through the base class first, even when the result is
        # swapped out below.
        resolved = super().find_class(module, name)
        lazy_name = {
            ("torch._utils", "_rebuild_tensor_v2"): "rebuild_tensor_v2",
            ("torch._tensor", "_rebuild_from_type_v2"): "rebuild_from_type_v2",
            ("torch._utils", "_rebuild_parameter"): "rebuild_parameter",
        }.get((module, name))
        if lazy_name is None:
            return resolved
        return functools.partial(getattr(NotYetLoadedTensor, lazy_name), archiveinfo=self)

    def persistent_load(self, pid):
        # pid is a 5-tuple: (tag, storage class, record key, location, numel).
        _tag, storage_cls, _key, _location, _numel = pid
        with warnings.catch_warnings():
            # TypedStorage / legacy storage classes emit deprecation warnings.
            warnings.simplefilter("ignore")
            placeholder = torch.storage.TypedStorage(dtype=storage_cls().dtype, device="meta")
        placeholder.archiveinfo = pid
        return placeholder
class lazy_load:
    """Context manager that lazily loads a PyTorch zip checkpoint.

    The constructor opens the archive and unpickles its ``data.pkl`` record
    with ``LazyLoadingUnpickler``, so tensors come back as
    ``NotYetLoadedTensor`` placeholders that read their data from the archive
    on demand. Entering the context returns the unpickled object (e.g. a state
    dict). Exiting drops this object's reference to the archive reader, after
    which placeholders that were not materialised can no longer be loaded.

    Example::

        with lazy_load("checkpoint.pth") as sd:
            ...
    """

    def __init__(self, fn):
        self.zf = torch._C.PyTorchFileReader(str(fn))
        pickled = self.zf.get_record("data.pkl")
        with BytesIO(pickled) as buffer:
            self.sd = LazyLoadingUnpickler(buffer, self).load()

    def __enter__(self):
        return self.sd

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The reader is not closed explicitly (no known way to force it);
        # dropping the reference lets it be released once unreferenced.
        self.zf = None
class SavingProxyForStorage:
    """Stand-in for a storage that ``saver`` has already written.

    On construction, writes ``obj`` through ``saver._write_storage_and_return_key``.
    Records in ``storage_info`` the persistent-id tuple
    ``("storage", storage_type, key, location, numel)`` that the pickler emits
    in place of the storage.

    Raises:
        TypeError: if ``obj`` is not a storage.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        # Instances must be intercepted by the pickler's persistent_id and never
        # reduced. This is an explicit raise rather than `assert`, so the guard
        # survives `python -O` (which would otherwise return None here).
        raise AssertionError(
            "SavingProxyForStorage should be handled out of band "
            "(pickle it with IncrementalPyTorchPickler)"
        )


class SavingProxyForTensor:
    """Stand-in for a tensor whose storage ``saver`` has already written.

    Reduces ``tensor`` with ``tensor.__reduce_ex__`` and replaces the storage
    argument with a ``SavingProxyForStorage``, which writes the storage
    immediately. Reducing this proxy returns the tensor's rebuild function
    with the storage proxy in the storage position.

    Raises:
        AssertionError: if the tensor's reduction does not start with a
            TypedStorage (e.g. a PyTorch version with a different layout).
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version)
        # Explicit check instead of `assert` so it is not stripped under `python -O`.
        if not isinstance(storage, torch.storage.TypedStorage):
            raise AssertionError(
                "Please check for updates: expected tensor.__reduce_ex__ to provide a "
                f"TypedStorage as the first rebuild argument, got {type(storage)}"
            )
        storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
        self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        if protocol_version != self.protocol_version:
            raise RuntimeError(
                f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
            )
        return self.reduce_ret_fn, self.reduce_args
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that writes storages through an incremental saver.

    Mirrors ``persistent_id`` from PyTorch 2.0+ ``torch/serialization.py``.
    The difference: instead of collecting storages in memory, it asks
    ``saver`` to write each distinct storage immediately and records the
    returned key. Storages already written via ``SavingProxyForStorage`` are
    referenced by their precomputed ``storage_info``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # data_ptr -> dtype of storages seen so far (detects aliasing with mismatched dtypes)
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> key returned by the saver (each storage is written once)
        self.id_map = {}

    def persistent_id(self, obj):
        # The pickle docs say persistent_id should return a string, but torch
        # returns tuples; that only works with binary pickle protocols.
        # https://docs.python.org/3/library/pickle.html#persistence-of-external-objects
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        is_typed = isinstance(obj, torch.storage.TypedStorage)
        if not (is_typed or torch.is_storage(obj)):
            return None

        if is_typed:
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            storage = obj._untyped_storage
            dtype = obj.dtype
            storage_type = getattr(torch, obj._pickle_storage_type())
            numel = obj._size()
        else:
            storage = obj
            dtype = torch.uint8
            storage_type = normalize_storage_type(type(obj))
            numel = storage.nbytes()

        # Allocated storages that share a data pointer must all be saved with
        # the same dtype; unallocated ones (null pointer) are exempt.
        ptr = storage.data_ptr()
        if ptr != 0:
            first_dtype = self.storage_dtypes.setdefault(ptr, dtype)
            if first_dtype != dtype:
                raise RuntimeError(
                    "Cannot save multiple tensors or storages that "
                    "view the same data as different types"
                )

        key = self.id_map.get(storage._cdata)
        if key is None:
            key = self.saver._write_storage_and_return_key(storage)
            self.id_map[storage._cdata] = key

        return ("storage", storage_type, key, torch.serialization.location_tag(storage), numel)
class incremental_save:
    """Context manager that writes a PyTorch zip-format checkpoint incrementally.

    Uses ``torch._C.PyTorchFileWriter``. Tensors can be written ahead of time
    with ``store_early``: their storages go to disk immediately, and a small
    proxy is returned. The final object, which may contain those proxies, is
    pickled once with ``save``. Leaving the context writes the archive's
    end-of-file record.

    Example::

        with incremental_save("out.pth") as saver:
            sd = {k: saver.store_early(v) for k, v in state_dict.items()}
            saver.save(sd)
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a proxy to embed in the object given to ``save``.

        Raises:
            TypeError: if ``tensor`` is not a ``torch.Tensor``.
        """
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        """Pickle ``obj`` into the archive's ``data.pkl`` record. May be called only once.

        Raises:
            RuntimeError: if ``save`` already succeeded.
        """
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        """Write ``storage`` as record ``data/<key>`` and return the new integer key.

        Raises:
            RuntimeError: if ``save`` already succeeded.
        """
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            # Copy to host memory so the bytes can be read through data_ptr().
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, exc_type, exc_value, traceback):
        # Finalise the archive unconditionally, even if the with-body raised.
        self.zipfile.write_end_of_file()