"""Utility functions for training and inference."""

import math
import pickle
import sys
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    ContextManager,
    Dict,
    List,
    Mapping,
    Optional,
    TypeVar,
    Union,
)

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from torch.serialization import normalize_storage_type

if TYPE_CHECKING:
    from model import GPT


def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is greater than or equal to ``n``.

    Args:
        n: the value to round up.
        k: the (positive) step to round up to.

    Returns:
        ``n`` itself if it is already divisible by ``k``, otherwise the next multiple of ``k``.
    """
    assert k > 0
    if n % k == 0:
        return n
    return n + k - (n % k)


def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count the parameter elements of ``module``.

    Args:
        module: the module whose ``parameters()`` are counted.
        requires_grad: if ``None`` every parameter is counted; otherwise only parameters
            whose ``requires_grad`` flag equals this value are counted.

    Returns:
        The total number of elements across the selected parameters.
    """
    total = 0
    for p in module.parameters():
        if requires_grad is None or p.requires_grad == requires_grad:
            if hasattr(p, "quant_state"):
                # bitsandbytes 4bit layer support
                # NOTE(review): assumes `quant_state[1]` is the unquantized shape; newer
                # bitsandbytes versions expose a QuantState object instead — confirm.
                total += math.prod(p.quant_state[1])
            else:
                total += p.numel()
    return total


def gptq_quantization(enabled: bool = False) -> ContextManager:
    """Return a context manager that replaces ``torch.nn.Linear`` with a 4-bit GPTQ layer.

    When ``enabled`` is False a no-op ``nullcontext`` is returned and the optional
    GPTQ dependencies are never imported.
    """
    if not enabled:
        return nullcontext()

    from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
    from quantize.gptq import ColBlockQuantizedLinear

    class QuantizedLinear(ColBlockQuantizedLinear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

    return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})


def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    """Exit the process unless ``checkpoint_dir`` is a usable checkpoint directory.

    The directory must contain ``lit_model.pth``, ``lit_config.json``,
    ``tokenizer_config.json`` and at least one of ``tokenizer.json`` / ``tokenizer.model``.
    On failure an explanatory message (listing any checkpoints found under
    ``checkpoints/*/*``) is printed to stderr.

    Raises:
        SystemExit: with code 1 if the directory is missing or incomplete.
    """
    files = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (
            checkpoint_dir / "tokenizer.json"
        ).is_file()
        or (checkpoint_dir / "tokenizer.model").is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join(
            [""] + [repr(str(p.resolve())) for p in available]
        )
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)


class SavingProxyForStorage:
    """Stand-in for a storage whose bytes are written to the archive immediately.

    The constructor writes the storage through ``saver`` and records the
    ``("storage", type, key, location, numel)`` tuple that
    ``IncrementalPyTorchPickler.persistent_id`` emits for this proxy, so the proxy
    itself is never pickled by value.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        # The pickler's `persistent_id` returns `storage_info` for this proxy,
        # so normal pickling never reaches this method.
        assert False, "this should be handled with out of band"


class SavingProxyForTensor:
    """Stand-in for a tensor whose storage has already been written out.

    It keeps the tensor's own ``__reduce_ex__`` result, replacing the storage
    argument with a ``SavingProxyForStorage`` (which writes the data eagerly).
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        # The reduce args were computed for `self.protocol_version`; refuse any other.
        if protocol_version != self.protocol_version:
            raise RuntimeError(
                f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
            )
        return self.reduce_ret_fn, self.reduce_args


class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that writes storages through ``saver`` and references them by persistent id.

    Each distinct storage (keyed by ``storage._cdata``) is written once.
    ``SavingProxyForStorage`` objects resolve to their precomputed ``storage_info``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage_dtypes = {}
        self.saver = saver
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            # Write each storage only once; later references reuse its key.
            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None


class incremental_save:
    """Context manager that writes a PyTorch zip archive incrementally.

    ``store_early`` writes a tensor's data to ``data/<key>`` records right away and
    returns a proxy to embed in the object later passed to ``save``. ``save`` writes
    ``data.pkl`` exactly once, and leaving the context finalizes the archive.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()


T = TypeVar("T")


def _mean_over_unmasked(loss_chunks: List[torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
    """Average per-token losses over the targets that are not ``-1`` (the ignore index).

    ``cross_entropy(..., ignore_index=-1, reduction="none")`` yields 0 for ignored
    targets. Dividing by the number of non-ignored targets (rather than by the number
    of tokens) matches ``reduction="mean"``. The count is clamped to 1 so a batch
    where every target is ignored yields 0 instead of NaN.
    """
    non_masked_elems = (targets != -1).sum()
    return torch.cat(loss_chunks).sum() / non_masked_elems.clamp(min=1)


def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
) -> torch.Tensor:
    """Cross entropy (``ignore_index=-1``) computed in chunks to reduce peak memory.

    Args:
        logits: either a single tensor, or a list of tensors chunked along dim 1
            (e.g. when the lm_head output was chunked).
        targets: target ids; entries equal to ``-1`` are ignored.
        chunk_size: number of rows per loss chunk when ``logits`` is a single tensor;
            ``0`` disables chunking of the loss computation.

    Returns:
        The mean loss over non-ignored targets. Chunked and unchunked paths now
        agree; if every target is ignored, the chunked paths return 0.
    """
    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
    # the memory usage in fine-tuning settings with low number of parameters.
    # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
    # the memory spike's magnitude

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # chunk cross entropy
        logit_chunks = [
            logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
        ]
        target_chunks = [
            target_chunk.reshape(-1)
            for target_chunk in targets.split(logits[0].size(1), dim=1)
        ]
        loss_chunks = [
            torch.nn.functional.cross_entropy(
                logit_chunk, target_chunk, ignore_index=-1, reduction="none"
            )
            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
        ]
        return _mean_over_unmasked(loss_chunks, targets)

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

    # lm_head wasn't chunked, chunk cross entropy
    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(
            logit_chunk, target_chunk, ignore_index=-1, reduction="none"
        )
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    return _mean_over_unmasked(loss_chunks, targets)


def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename ``prefix + old`` keys to ``prefix + new`` for each ``old -> new`` in ``mapping``.

    ``state_dict`` is modified in place and also returned. Keys not present are skipped.
    """
    for checkpoint_name, attribute_name in mapping.items():
        full_checkpoint_name = prefix + checkpoint_name
        if full_checkpoint_name in state_dict:
            full_attribute_name = prefix + attribute_name
            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
    return state_dict


def get_default_supported_precision(training: bool) -> str:
    """Return default precision that is supported by the hardware: either `bf16` or `16`.

    `16` is chosen when an MPS accelerator is available or the CUDA device lacks bf16
    support; otherwise `bf16` is used.

    Args:
        training: if True, return the `-mixed` variant of the precision; otherwise the `-true` variant.

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    if MPSAccelerator.is_available() or (
        torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
    ):
        return "16-mixed" if training else "16-true"
    return "bf16-mixed" if training else "bf16-true"


def load_checkpoint(
    fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
) -> None:
    """Load the weights at ``checkpoint_path`` into ``model``.

    Under an FSDP strategy this delegates to ``fabric.load_raw``. Otherwise, the file
    is lazily loaded, a top-level ``"model"`` entry is unwrapped if present, and the
    result is passed to ``model.load_state_dict``.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
    else:
        state_dict = lazy_load(checkpoint_path)
        state_dict = state_dict.get("model", state_dict)
        model.load_state_dict(state_dict, strict=strict)


def flops_per_param(
    max_seq_length: int, n_layer: int, n_embd: int, n_params: int
) -> int:
    """Estimate the FLOPs of one pass over a single sequence of ``max_seq_length`` tokens.

    The estimate is ``2 * n_params * max_seq_length`` (one multiply-accumulate per
    parameter per token) plus an attention term
    ``n_layer * 2 * 2 * n_embd * max_seq_length**2``.
    """
    flops_per_token = (
        2 * n_params
    )  # each parameter is used for a MAC (2 FLOPS) per network operation
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    flops_per_seq = flops_per_token * max_seq_length
    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return flops_per_seq + attn_flops_per_seq


def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Trainable parameters are weighted by 3 passes when training (1 otherwise), and
    frozen parameters by 2 passes when training (1 otherwise).

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, making those lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
    n_trainable_params = num_parameters(model, requires_grad=True)
    trainable_flops = flops_per_param(
        model.max_seq_length,
        model.config.n_layer,
        model.config.n_embd,
        n_trainable_params,
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if training else 1
    n_frozen_params = num_parameters(model, requires_grad=False)
    frozen_flops = flops_per_param(
        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
    )
    # forward + backward
    frozen_ops_per_step = 2 if training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops