"""Text generation with a GPT model built from the ``pythia-160m`` config.

At import time this module creates a Lightning ``Fabric``, instantiates the
model, loads weights from ``checkpoint_dir`` and builds the tokenizer. It then
exposes ``generate`` (token-level sampling loop) and ``generate_context``
(string prompt in, decoded string out).
"""
import random
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import lightning as L
from torch.utils.data import DataLoader
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy

from tsai_gpt.model import GPT, Block, Config
from tsai_gpt.tokenizer import Tokenizer
from tsai_gpt.utils import get_default_supported_precision, load_checkpoint, gptq_quantization

model_name = "pythia-160m"
name = "redpajama"
# NOTE(review): despite the name, this appears to be a single ``.pth`` checkpoint
# file rather than a directory.
checkpoint_dir = Path("iter-015000-ckpt.pth")
quantize = None
strategy = "auto"
devices = 1
precision = get_default_supported_precision(training=False)
plugins = None

fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
fabric.launch()

# NOTE(review): the first and third prompts are identical -- likely a copy/paste
# slip; consider replacing one of them.
example_text = [
    "In the middle of the enchanted forest, there was a magical pond where the water shimmered with a faint glow of",
    "The detective carefully examined the crime scene, searching for any clues that might lead to the identity of the",
    "In the middle of the enchanted forest, there was a magical pond where the water shimmered with a faint glow of",
    "The time machine malfunctioned, sending the protagonist to a dystopian future where robots had taken over and humans were forced to live underground to escape the threat of ",
    "In the parallel universe, gravity worked differently, causing objects to float in the air as if affected by an invisible",
]

# Each entry is [prompt, temperature in [0.6, 0.9] (1 decimal),
# a multiple of 10 in [120, 250], a multiple of 10 in [50, 100]].
# Presumably consumed as the (input_text, temperature, max_tokens, top_k)
# arguments of ``generate_context`` -- confirm against the caller.
examples = [
    [
        text,
        round(random.uniform(0.6, 0.9), 1),
        round(int(random.uniform(120, 250)) / 10) * 10,
        round(int(random.uniform(50, 100)) / 10) * 10,
    ]
    for text in example_text
]

with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
    config = Config.from_name(model_name)
    model = GPT(config)

model.eval()
model = fabric.setup_module(model)
load_checkpoint(fabric, model, checkpoint_dir)

tokenizer = Tokenizer(Path("tokenizer_files"))


@torch.inference_mode()
def generate(
    model: GPT,
    idx: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
            Must be strictly greater than T.
        temperature: Scales the predicted logits by 1 / temperature. A temperature of 0
            selects the highest-scoring token at every step (greedy decoding).
        top_k: If specified, only sample among the tokens with the k highest probabilities.
            Ignored when ``temperature == 0``.
        eos_id: If specified, stop generating any more tokens once this token is sampled.

    Returns:
        A 1-D tensor holding the prompt followed by the generated tokens. If ``eos_id``
        is hit, the output is truncated just before the EOS token.

    Raises:
        AssertionError: if ``max_returned_tokens`` does not exceed the prompt length.
        NotImplementedError: if ``model.max_seq_length`` is smaller than
            ``max_returned_tokens - 1``.
    """
    T = idx.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    # generate up to a fixed number of tokens
    for _ in range(max_returned_tokens - T):
        # The first step feeds the whole prompt; later steps feed only the newest token.
        # Presumably earlier positions are served from the model's KV cache -- confirm in GPT.forward.
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, input_pos)
        logits = logits[0, -1]

        if temperature == 0.0:
            # Dividing by zero would make the logits inf/nan and torch.multinomial would
            # raise; fall back to greedy decoding instead.
            idx_next = torch.argmax(logits, dim=-1, keepdim=True).to(dtype=dtype)
        else:
            logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if the EOS token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            # The slice ends before `input_pos`, which is where EOS was written, so EOS is excluded.
            return idx[:input_pos]
    return idx


def generate_context(input_text, temperature, max_tokens, top_k):
    """Generate a continuation of ``input_text`` with the module-level model.

    Args:
        input_text: Prompt string, encoded with the module-level ``tokenizer``.
        temperature: Forwarded to ``generate`` (0 means greedy decoding).
        max_tokens: Number of new tokens to generate; must be at least 1.
        top_k: Forwarded to ``generate``; ``None`` disables top-k filtering.

    Returns:
        The decoded text of prompt plus generated tokens. The prompt is included
        because ``generate`` returns the full sequence.
    """
    encoded = tokenizer.encode(input_text, device=fabric.device)
    max_returned_tokens = encoded.size(0) + max_tokens

    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens

    with fabric.init_tensor():
        # Re-created on every call, so no cached state leaks between prompts.
        model.set_kv_cache(batch_size=1)

    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
    return tokenizer.decode(y)