# File size: 5,152 Bytes
# b3c4079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import random
import torch.nn as nn
import lightning as L
from pathlib import Path
from torch.utils.data import DataLoader
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy

from tsai_gpt.model import GPT, Block, Config
from tsai_gpt.tokenizer import Tokenizer
from tsai_gpt.utils import get_default_supported_precision, load_checkpoint, gptq_quantization

# Name of the model configuration; resolved through Config.from_name when the model is built.
model_name = "pythia-160m"
# NOTE(review): not referenced anywhere else in the visible code; presumably the
# dataset the checkpoint was trained on — confirm before relying on or removing it.
name = "redpajama"

# Despite the "_dir" suffix this path names a single .pth file; it is passed to load_checkpoint.
checkpoint_dir = Path("iter-015000-ckpt.pth")
# Quantization mode. The only value checked for later is "gptq.int4"; None makes that check False.
quantize = None
# Arguments forwarded unchanged to L.Fabric below.
strategy = "auto"
devices = 1
# Precision is chosen by the project helper with training=False (exact choice is made in tsai_gpt.utils).
precision = get_default_supported_precision(training=False)
plugins = None
# Create the Fabric object and launch it at import time; everything below uses this instance.
fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
fabric.launch()

# Prompts offered as ready-made example inputs. Each prompt appears exactly once
# (an earlier revision listed the "enchanted forest" prompt twice).
example_text = [
    "In the middle of the enchanted forest, there was a magical pond where the water shimmered with a faint glow of",
    "The detective carefully examined the crime scene, searching for any clues that might lead to the identity of the",
    "The time machine malfunctioned, sending the protagonist to a dystopian future where robots had taken over and humans were forced to live underground to escape the threat of ",
    "In the parallel universe, gravity worked differently, causing objects to float in the air as if affected by an invisible",
]


def _random_multiple_of_ten(low, high):
    """Draw a uniform value in [low, high], truncate it to an int and round to a multiple of ten.

    Uses Python's built-in ``round`` on ``value / 10``, so exact halves go to the
    even neighbour (e.g. 125 -> 120).
    """
    return round(int(random.uniform(low, high)) / 10) * 10


# One row per prompt: [text, temperature, max_tokens, top_k] — the same order as
# the parameters of generate_context. The numeric settings are drawn at import
# time, so they differ between runs.
examples = [
    [
        text,
        round(random.uniform(0.6, 0.9), 1),  # temperature with one decimal place
        _random_multiple_of_ten(120, 250),  # number of new tokens to generate
        _random_multiple_of_ten(50, 100),  # top-k cutoff
    ]
    for text in example_text
]

# Build the model inside two context managers: Fabric's init_module (with
# empty_init=True) and the project's gptq_quantization, which receives True only
# when quantize == "gptq.int4".
# NOTE(review): empty_init=True presumably skips weight initialisation because the
# checkpoint load below overwrites the parameters — confirm against the Fabric docs.
with fabric.init_module(empty_init=True), gptq_quantization(quantize=="gptq.int4"):
    config = Config.from_name(model_name)
    model = GPT(config)

# Put the model in evaluation mode, hand it to Fabric, then load the checkpoint
# file named by checkpoint_dir into it.
model.eval()
model = fabric.setup_module(model)
load_checkpoint(fabric, model, checkpoint_dir)

# Tokenizer constructed from the relative path 'tokenizer_files'.
tokenizer = Tokenizer(Path('tokenizer_files'))

@torch.inference_mode()
def generate(
    model: "GPT",
    idx: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: int = None,
    eos_id: int = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    The implementation of this function is modified from A. Karpathy's nanoGPT.
    Args:
        model: The model to use. It is called as ``model(x, input_pos)`` and must expose ``max_seq_length``.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature. A value of exactly 0 selects
            the most likely token at every step (greedy decoding) instead of sampling.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered.
    Returns:
        A 1-D tensor with the prompt followed by the generated tokens. When ``eos_id`` is produced
        the tensor ends with that token; otherwise it has exactly ``max_returned_tokens`` entries.
    Raises:
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = idx.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    # the first forward pass sees the whole prompt; later passes see only the newest token
    input_pos = torch.arange(0, T, device=device)

    # generate up to a fixed number of tokens
    for _ in range(max_returned_tokens - T):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward: keep only the logits of the last position
        logits = model(x, input_pos)[0, -1]

        if temperature == 0:
            # dividing by zero would turn the logits into inf/nan and break sampling;
            # the zero-temperature limit is simply the arg-max token
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            # optionally crop the logits to only the top k options
            if top_k is not None:
                top_values = torch.topk(logits, min(top_k, logits.size(-1))).values
                logits = torch.where(logits < top_values[[-1]], -float("Inf"), logits)

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx_next = idx_next.to(dtype=dtype)

        # advance: the new token lives at the position right after the last one fed to the model
        input_pos = input_pos[-1:] + 1

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if eos_id is not None and idx_next == eos_id:
            # the EOS token was written at index `input_pos`, so the slice has to end one past it
            return idx[: input_pos + 1]  # include the EOS token

    return idx

def generate_context(input_text, temperature, max_tokens, top_k):
    """Continue ``input_text`` with the module-level model and return the decoded text.

    Args:
        input_text: Prompt string, encoded with the module-level ``tokenizer``.
        temperature: Forwarded to ``generate`` as ``temperature``.
        max_tokens: Number of tokens to add on top of the encoded prompt length.
        top_k: Forwarded to ``generate`` as ``top_k``.

    Returns:
        Whatever ``tokenizer.decode`` produces for the tensor returned by ``generate``
        (prompt tokens followed by the generated ones).
    """
    prompt_ids = tokenizer.encode(input_text, device=fabric.device)

    # total length handed to `generate`: prompt tokens plus the requested new tokens
    total_tokens = prompt_ids.size(0) + max_tokens

    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = total_tokens

    with fabric.init_tensor():
        # (re)create the kv cache for a single sequence
        model.set_kv_cache(batch_size=1)

    generated = generate(model, prompt_ids, total_tokens, temperature=temperature, top_k=top_k)
    return tokenizer.decode(generated)