"""Utility functions for training and inference."""
import math
import pickle
import sys
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union
import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from torch.serialization import normalize_storage_type
from model import GPT
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the closest multiple of ``k``.

    Args:
        n: Value to round.
        k: Step to round to; must be positive.

    Returns:
        ``n`` unchanged when it is already divisible by ``k``, otherwise the next larger multiple.
    """
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count the parameter elements of ``module``.

    Args:
        module: Module whose ``parameters()`` are counted.
        requires_grad: ``None`` counts every parameter; otherwise only parameters whose
            ``requires_grad`` flag equals this value are counted.

    Returns:
        Total number of elements across the selected parameters.
    """
    total = 0
    for p in module.parameters():
        if requires_grad is not None and p.requires_grad != requires_grad:
            continue
        quant_state = getattr(p, "quant_state", None)
        if quant_state is None:
            total += p.numel()
        else:
            # bitsandbytes 4bit layer support: the stored tensor is packed, so its numel() is not the
            # logical size. Newer bitsandbytes exposes a `QuantState` object with `.shape`; older
            # releases stored a list whose second entry is the original shape.
            shape = getattr(quant_state, "shape", None)
            if shape is None:
                shape = quant_state[1]
            total += math.prod(shape)
    return total
def gptq_quantization(enabled: bool = False) -> ContextManager:
    """Build a context manager that substitutes ``torch.nn.Linear`` with a 4-bit GPTQ linear layer.

    Args:
        enabled: When false, a no-op ``nullcontext()`` is returned and the optional GPTQ
            dependencies are not imported at all.

    Returns:
        A context manager; when enabled, ``torch.nn.Linear`` is replaced by a
        ``ColBlockQuantizedLinear`` subclass (``bits=4``, ``tile_cols=-1``) while it is active.
    """
    if enabled:
        from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
        from quantize.gptq import ColBlockQuantizedLinear

        class QuantizedLinear(ColBlockQuantizedLinear):
            # Pin the quantization settings so callers only need nn.Linear's own arguments.
            def __init__(self, *args, **kwargs):
                super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

        return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
    return nullcontext()
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    """Exit with a helpful message unless ``checkpoint_dir`` contains a complete checkpoint.

    A directory is accepted when it holds ``lit_model.pth``, ``lit_config.json``,
    ``tokenizer_config.json`` and at least one of ``tokenizer.json`` / ``tokenizer.model``.

    Args:
        checkpoint_dir: Directory to validate.

    Raises:
        SystemExit: With exit code 1, after printing the problem, the locally available
            checkpoints under ``checkpoints/*/*`` and download hints to stderr.
    """
    files = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
        or (checkpoint_dir / "tokenizer.model").is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints (sorted so the suggestions are stable across runs)
    available = sorted(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)
class SavingProxyForStorage:
    """Placeholder for a storage whose bytes are written to ``saver`` at construction time.

    ``__init__`` calls ``saver._write_storage_and_return_key`` right away and records the
    ``("storage", storage_type, key, location, numel)`` tuple in ``storage_info``; the incremental
    pickler returns that tuple from ``persistent_id`` instead of pickling this object.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)

    def __reduce_ex__(self, protocol_version):
        # Proxies must be intercepted by the pickler's persistent_id. Raise explicitly rather than
        # `assert False`, which is stripped under `python -O`.
        raise RuntimeError("this should be handled with out of band")
class SavingProxyForTensor:
    """Picklable stand-in for a tensor whose storage is written to ``saver`` immediately.

    It captures the tensor's own ``__reduce_ex__`` result and swaps the storage argument for a
    ``SavingProxyForStorage``, so pickling later reproduces the tensor's rebuild call while the
    storage itself is emitted as a persistent reference.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes: the storage is the first element of the third
            # reduce argument
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            # plain tensors: the storage is the first reduce argument
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        # The captured reduce tuple is only valid for the protocol it was produced with.
        if protocol_version != self.protocol_version:
            raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
        return self.reduce_ret_fn, self.reduce_args
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that stores tensor storages through ``saver`` and pickles only references to them.

    Storages and ``SavingProxyForStorage`` placeholders are emitted via ``persistent_id`` as
    ``("storage", storage_type, key, location, numel)`` tuples. Each distinct storage (keyed by
    ``storage._cdata``) is written once with ``saver._write_storage_and_return_key``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # data_ptr -> dtype of the first allocated storage seen at that address
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> key returned by the saver, so shared storages are written only once
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-pickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # already written when the proxy was created
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()
            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        # not a storage: pickle normally
        return None
class incremental_save:
    """Context manager that writes a PyTorch zip checkpoint to ``name`` piece by piece.

    ``store_early`` writes a tensor's storage into the archive immediately and returns a picklable
    proxy. ``save`` pickles the final object graph into ``data.pkl`` and may be called only once.
    Leaving the ``with`` block finalizes the archive.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        # next record index; storages are written as "data/<key>"
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a proxy to embed in the object passed to ``save``.

        Raises:
            TypeError: If ``tensor`` is not a ``torch.Tensor``.
        """
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        """Pickle ``obj`` into the ``data.pkl`` record; storages met while pickling become records.

        Raises:
            RuntimeError: If ``save`` has already been called.
        """
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        """Write ``storage`` (copied to CPU first if needed) as record ``data/<key>`` and return ``key``."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, exc_type, exc_value, traceback):
        # Write the archive's end-of-file records so the checkpoint is a complete zip file.
        self.zipfile.write_end_of_file()
# Module-level generic type variable.
# NOTE(review): not referenced in the visible part of this module; confirm it is used elsewhere before removing.
T = TypeVar("T")
def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
    ignore_index: int = -1,
) -> torch.Tensor:
    """Cross-entropy loss computed chunk by chunk to reduce the memory spike of ``backward``.

    With large max_sequence_lengths, the beginning of ``backward`` allocates a large memory chunk
    which can dominate the memory usage in fine-tuning settings with a low number of parameters.
    As a workaround, the cross entropy computation is chunked to force it to deallocate on the go,
    reducing the memory spike's magnitude.

    Args:
        logits: Either one logits tensor (class dimension last), or a list of logits tensors that
            were split along dim 1 (the lm_head output was chunked).
        targets: Class indices matching ``logits``; entries equal to ``ignore_index`` are skipped.
        chunk_size: Rows per chunk when ``logits`` is a single tensor. ``0`` disables chunking
            (for list inputs only ``0`` vs. non-zero matters; the list's own split is reused).
        ignore_index: Target value excluded from both the loss and the averaging denominator.

    Returns:
        Scalar mean loss over the non-ignored targets. The chunked paths return 0 instead of
        dividing by zero when every target is ignored.
    """
    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)

        # chunk cross entropy, reusing the lm_head chunking along dim 1
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
    else:
        # no chunking at all
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
        if chunk_size == 0:
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)

        # lm_head wasn't chunked, chunk cross entropy
        logit_chunks = logits.split(chunk_size)
        target_chunks = targets.split(chunk_size)

    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    # With reduction="none", ignored positions contribute a 0 loss. A plain `.mean()` would still
    # count them in the denominator and under-weight the loss. Divide by the number of non-ignored
    # targets so the result matches the unchunked `cross_entropy(..., ignore_index=...)` mean.
    non_ignored = (targets != ignore_index).sum()
    return torch.cat(loss_chunks).sum() / non_ignored.clamp(min=1)
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename legacy checkpoint keys of ``state_dict`` in place.

    For every ``old -> new`` pair in ``mapping``, an entry stored under ``prefix + old`` is moved to
    ``prefix + new`` with the same value. Pairs whose old key is absent are skipped.

    Returns:
        The same ``state_dict`` object, after mutation.
    """
    for old_name, new_name in mapping.items():
        old_key = prefix + old_name
        if old_key not in state_dict:
            continue
        new_key = prefix + new_name
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
def get_default_supported_precision(training: bool) -> str:
    """Pick a default precision string the current hardware supports: ``bf16`` or ``16``.

    Float16 is chosen when an MPS accelerator is available, or when CUDA is available but lacks
    bfloat16 support; bfloat16 is chosen otherwise.

    Args:
        training: Selects the ``-mixed`` variant when true, the ``-true`` variant otherwise.

    Returns:
        One of ``"16-mixed"``, ``"16-true"``, ``"bf16-mixed"``, ``"bf16-true"``.
    """
    from lightning.fabric.accelerators import MPSAccelerator

    fp16_only = MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported())
    precision = "16" if fp16_only else "bf16"
    variant = "mixed" if training else "true"
    return f"{precision}-{variant}"
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Under an ``FSDPStrategy`` loading is delegated entirely to ``fabric.load_raw``. Otherwise the
    file is read with ``lazy_load``; if it has a ``"model"`` entry, that entry is used as the state
    dict, else the loaded object itself is.

    Args:
        fabric: Fabric instance whose strategy decides the loading path.
        model: Module that receives the weights.
        checkpoint_path: Path of the checkpoint file.
        strict: Forwarded to the underlying load call.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
        # FSDP path is complete; do not load the checkpoint a second time below.
        return
    state_dict = lazy_load(checkpoint_path)
    state_dict = state_dict.get("model", state_dict)
    model.load_state_dict(state_dict, strict=strict)
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    """Estimate FLOPs for one pass of ``n_params`` parameters over a ``max_seq_length`` sequence.

    The result is a parameter term, ``2 * n_params * max_seq_length``, plus an attention term,
    ``4 * n_layer * n_embd * max_seq_length**2``.
    """
    # each parameter is used for a MAC (2 FLOPS) per network operation
    param_flops = 2 * n_params * max_seq_length
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    attention_flops = 4 * n_layer * (n_embd * max_seq_length**2)
    return param_flops + attention_flops
def estimate_flops(model: "GPT", training: bool) -> int:
    """Estimate the FLOPs spent on one ``model.max_seq_length``-token sequence, for MFU.

    Trainable-parameter FLOPs are weighted 3x when ``training`` (forward + backward + gradients,
    assuming no gradient accumulation) and frozen-parameter FLOPs 2x (forward + backward). Both are
    weighted 1x otherwise.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, making those lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
    seq_len = model.max_seq_length
    n_layer = model.config.n_layer
    n_embd = model.config.n_embd

    # NOTE(review): flops_per_param includes an attention term, so it is counted for both the
    # trainable and the frozen call below; confirm this double counting is intended.
    trainable_flops = flops_per_param(seq_len, n_layer, n_embd, num_parameters(model, requires_grad=True))
    frozen_flops = flops_per_param(seq_len, n_layer, n_embd, num_parameters(model, requires_grad=False))

    if training:
        trainable_ops, frozen_ops = 3, 2
    else:
        trainable_ops, frozen_ops = 1, 1
    return trainable_ops * trainable_flops + frozen_ops * frozen_flops