TharunSivamani committed
Commit 4e54cb1
1 Parent(s): bd05363
Files changed (1)
  1. utils.py +131 -0
utils.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import random
+ import torch.nn as nn
+ import lightning as L
+ from pathlib import Path
+ from typing import Optional
+ from torch.utils.data import DataLoader
+ from lightning.fabric.loggers import CSVLogger
+ from lightning.fabric.strategies import FSDPStrategy
+
+ from tsai_gpt.model import GPT, Block, Config
+ from tsai_gpt.tokenizer import Tokenizer
+ from tsai_gpt.utils import get_default_supported_precision, load_checkpoint, gptq_quantization
+
+ model_name = "pythia-160m"
+ name = "redpajama"
+
+ checkpoint_dir = Path("iter-015000-ckpt.pth")
+ quantize = None
+ strategy = "auto"
+ devices = 1
+ precision = get_default_supported_precision(training=False)
+ plugins = None
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
+ fabric.launch()
+
+ example_text = [
+     "In the middle of the enchanted forest, there was a magical pond where the water shimmered with a faint glow of",
+     "The detective carefully examined the crime scene, searching for any clues that might lead to the identity of the",
+     "In the middle of the enchanted forest, there was a magical pond where the water shimmered with a faint glow of",
+     "The time machine malfunctioned, sending the protagonist to a dystopian future where robots had taken over and humans were forced to live underground to escape the threat of ",
+     "In the parallel universe, gravity worked differently, causing objects to float in the air as if affected by an invisible",
+ ]
+
+ # each example is [prompt, temperature, max new tokens, top-k]; the numeric values are
+ # sampled at random and rounded to one decimal place / the nearest ten
+ examples = [
+     [
+         text,
+         round(random.uniform(0.6, 0.9), 1),
+         round(int(random.uniform(120, 250)) / 10) * 10,
+         round(int(random.uniform(50, 100)) / 10) * 10,
+     ]
+     for text in example_text
+ ]
+
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
+     config = Config.from_name(model_name)
+     model = GPT(config)
+
+ model.eval()
+ model = fabric.setup_module(model)
+ load_checkpoint(fabric, model, checkpoint_dir)
+
+ tokenizer = Tokenizer(Path("tokenizer_files"))
+
+
+ @torch.inference_mode()
+ def generate(
+     model: GPT,
+     idx: torch.Tensor,
+     max_returned_tokens: int,
+     *,
+     temperature: float = 1.0,
+     top_k: Optional[int] = None,
+     eos_id: Optional[int] = None,
+ ) -> torch.Tensor:
+     """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+
+     The implementation of this function is modified from A. Karpathy's nanoGPT.
+
+     Args:
+         model: The model to use.
+         idx: Tensor of shape (T) with indices of the prompt sequence.
+         max_returned_tokens: The maximum number of tokens to return (given plus generated).
+         temperature: Scales the predicted logits by 1 / temperature.
+         top_k: If specified, only sample among the tokens with the k highest probabilities.
+         eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.
+     """
+     T = idx.size(0)
+     assert max_returned_tokens > T
+     if model.max_seq_length < max_returned_tokens - 1:
+         # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
+         # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
+         # not support it to avoid negatively impacting the overall speed
+         raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
+
+     device, dtype = idx.device, idx.dtype
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
+     empty[:T] = idx
+     idx = empty
+     input_pos = torch.arange(0, T, device=device)
+
+     # generate up to a fixed number of tokens
+     for _ in range(max_returned_tokens - T):
+         x = idx.index_select(0, input_pos).view(1, -1)
+
+         # forward
+         logits = model(x, input_pos)
+         logits = logits[0, -1] / temperature
+
+         # optionally crop the logits to only the top k options
+         if top_k is not None:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
+
+         # advance
+         input_pos = input_pos[-1:] + 1
+
+         # concatenate the new generation
+         idx = idx.index_copy(0, input_pos, idx_next)
+
+         # if the <eos> token is triggered, stop generating
+         if idx_next == eos_id:
+             return idx[:input_pos]  # everything up to (but not including) the <eos> token
+
+     return idx
+
+
+ def generate_context(input_text, temperature, max_tokens, top_k):
+     encoded = tokenizer.encode(input_text, device=fabric.device)
+     max_returned_tokens = encoded.size(0) + max_tokens
+
+     with fabric.init_tensor():
+         # set the max_seq_length to limit the memory usage to what we need
+         model.max_seq_length = max_returned_tokens
+
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+
+     y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
+
+     return tokenizer.decode(y)
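
The sampling step inside `generate` combines temperature scaling with top-k filtering: logits are divided by the temperature, everything below the k-th largest logit is masked to -inf, and the next token is drawn from the resulting softmax. A self-contained illustration of just that step (the logit values and the vocabulary size are made up for the example, not taken from this commit):

import torch

logits = torch.tensor([2.0, 0.5, -1.0, 1.5])  # made-up logits over a 4-token vocabulary
temperature, top_k = 0.8, 2

scaled = logits / temperature                                  # temperature < 1 sharpens the distribution
v, _ = torch.topk(scaled, min(top_k, scaled.size(-1)))
scaled = torch.where(scaled < v[[-1]], -float("Inf"), scaled)  # mask everything below the k-th largest logit
probs = torch.nn.functional.softmax(scaled, dim=-1)            # only the top-k entries keep non-zero probability
next_token = torch.multinomial(probs, num_samples=1)
print(probs, next_token)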
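For reference, a minimal usage sketch of `generate_context`, not part of this commit. It assumes the checkpoint and tokenizer files referenced above are present on disk and that this file is importable as `utils`. Since each entry of `examples` is shaped like [prompt, temperature, max new tokens, top-k], a hypothetical Gradio wiring is also shown; the Gradio import, the `_run` wrapper, and the slider ranges/labels are assumptions, not something this commit defines.

import gradio as gr  # assumption: Gradio is not imported anywhere in this commit

from utils import examples, generate_context  # importing this module loads the model as a side effect

# direct call with one of the bundled examples: [prompt, temperature, max new tokens, top-k]
prompt, temperature, max_tokens, top_k = examples[0]
print(generate_context(prompt, temperature, max_tokens, top_k))


def _run(prompt, temperature, max_tokens, top_k):
    # Gradio sliders return floats; generate_context expects integer token counts
    return generate_context(prompt, temperature, int(max_tokens), int(top_k))


# hypothetical UI wiring that matches the shape of `examples`
demo = gr.Interface(
    fn=_run,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(0.1, 1.0, value=0.8, label="Temperature"),
        gr.Slider(10, 250, value=150, step=10, label="Max new tokens"),
        gr.Slider(10, 100, value=50, step=10, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    examples=examples,
)
demo.launch()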