Sijuade committed
Commit
6b4240d
1 parent: ab3a051

Upload 3 files

Files changed (3)
  1. app.py +90 -0
  2. requirements.txt +6 -0
  3. utils.py +133 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import gradio as gr
+ from utils import examples, generate_dialogue
+
+ HTML_TEMPLATE = """
+ <style>
+ #app-header {
+     text-align: center;
+     background: rgba(255, 255, 255, 0.3); /* Semi-transparent white */
+     padding: 20px;
+     border-radius: 10px;
+     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+     position: relative; /* To position the artifacts */
+ }
+ #app-header h1 {
+     color: #FF0000;
+     font-size: 2em;
+     margin-bottom: 10px;
+ }
+ .concept {
+     position: relative;
+     transition: transform 0.3s;
+ }
+ .concept:hover {
+     transform: scale(1.1);
+ }
+ .concept img {
+     width: 100px;
+     border-radius: 10px;
+     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ }
+ .concept-description {
+     position: absolute;
+     bottom: -30px;
+     left: 50%;
+     transform: translateX(-50%);
+     background-color: #4CAF50;
+     color: white;
+     padding: 5px 10px;
+     border-radius: 5px;
+     opacity: 0;
+     transition: opacity 0.3s;
+ }
+ .concept:hover .concept-description {
+     opacity: 1;
+ }
+ </style>
+ <div id="app-header">
+     <!-- Artifacts -->
+     <div class="artifact large"></div>
+     <div class="artifact large"></div>
+     <div class="artifact large"></div>
+     <div class="artifact large"></div>
+     <!-- Content -->
+     <h1>GPT NEXT WORD GENERATOR</h1>
+     <p>Generates dialogue given an initial prompt that sets the context.</p>
+     <p>Model: GPT, Dataset: arxiv + book + cc, Parameter Count: 160M</p>
+ </div>
+ """
+
+ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {background: url('file=https://github.com/Delve-ERAV1/Conditional-Diffusion/assets/11761529/1ff9d2e1-798f-442a-a1e2-386fdd35010a')}") as interface:
+     gr.HTML(value=HTML_TEMPLATE, show_label=False)
+
+     # Vertical spacers to push the controls below the header background
+     for _ in range(43):
+         gr.Markdown("")
+
+     with gr.Row():
+         input_text = gr.Textbox(
+             label="Input Text",
+             value="Enter your prompt here: This text will set the context for the AI's response."
+         )
+
+         temperature_dropdown = gr.Slider(0, 1, value=0.8, label="Temperature", info="Set the creativity level: higher values produce more varied results, lower values generate more predictable text.")
+         top_k_dropdown = gr.Slider(50, 300, value=200, label="Top K", info="Control the randomness: limits the model to sampling from only the K most likely next tokens.")
+         max_new_tokens = gr.Slider(1, 100, value=50, label="Max Tokens", info="Choose the length: the maximum number of tokens the model will generate.")
+
+     outputs = gr.Textbox(
+         label="Generated Dialogue"
+     )
+     inputs = [input_text, temperature_dropdown, top_k_dropdown, max_new_tokens]
+
+     with gr.Column():
+         button = gr.Button("Generate")
+         button.click(generate_dialogue, inputs=inputs, outputs=outputs)
+
+     with gr.Row():
+         gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=generate_dialogue, cache_examples=True)
+
+ interface.launch()
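For reference, clicking Generate simply invokes the handler with the four input values in order. A minimal direct-call sketch (assuming the checkpoint and tokenizer that utils.py loads at import time are available locally; the prompt here is illustrative):

    from utils import generate_dialogue

    # argument order matches the Gradio inputs: prompt, temperature, top_k, max_tokens
    text = generate_dialogue("In a galaxy far, far away", 0.8, 200, 50)
    print(text)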
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ numpy
+ pandas
+ lightning
+ sentencepiece
+ gradio
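With this file in place, the app's dependencies can be installed with:

    pip install -r requirements.txt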
utils.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import random
+ import lightning as L
+ from pathlib import Path
+ from typing import Optional
+
+ from tsai_gpt.model import GPT, Config
+ from tsai_gpt.tokenizer import Tokenizer
+ from tsai_gpt.utils import get_default_supported_precision, load_checkpoint, gptq_quantization
+
+
+ example_text = [
+     "In a galaxy far, far away, an intergalactic council convenes to discuss the rising cost of lightsaber batteries. Among them is an unlikely representative: a droid with a penchant for economics...",
+     "As Sherlock Holmes and Dr. Watson enter the world of social media influencers, they find their first case: the mysterious disappearance of a famous TikTok star's like button.",
+     "In the midst of a zombie apocalypse, a group of survivors discovers a library with every book intact except for cookbooks. Their leader, a former TV chef, decides to write the ultimate survival recipe book titled...",
+     "A time traveler accidentally attends Shakespeare's first play, but instead of a quill, she hands him a smartphone with autocorrect. The resulting play is...",
+     "Amidst the chaos of a Hogwarts School reunion, a magical mishap swaps the voices of Professors Dumbledore and Snape, leading to an unexpected duet in the Great Hall that goes viral in the wizarding world."
+ ]
+
+ # Each example row follows the order of the Gradio inputs: text, temperature, top_k, max_tokens.
+ examples = [
+     [
+         text,
+         round(random.uniform(0, 1), 1),  # temperature
+         random.randint(50, 300),         # top_k, within the slider's range
+         random.randint(50, 100),         # max_tokens, within the slider's range
+     ]
+     for text in example_text
+ ]
+
+
+ model_name = "pythia-160m"
+ name = "redpajama"
+
+ checkpoint_dir = Path("iter-010915-ckpt.pth")
+ quantize = None
+ strategy = "auto"
+ devices = 1
+ precision = get_default_supported_precision(training=False)
+ plugins = None
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
+ fabric.launch()
+
+
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
+     config = Config.from_name(model_name)
+     model = GPT(config)
+
+ model.eval()
+ model = fabric.setup_module(model)
+ load_checkpoint(fabric, model, checkpoint_dir)
+
+ tokenizer = Tokenizer(Path('tokenizer'))
+
+
+ def generate_dialogue(input_text, temperature, top_k, max_tokens):
+     # The argument order matches the Gradio `inputs` list in app.py.
+     encoded = tokenizer.encode(input_text, device=fabric.device)
+     max_returned_tokens = encoded.size(0) + int(max_tokens)
+
+     with fabric.init_tensor():
+         # set the max_seq_length to limit the memory usage to what we need
+         model.max_seq_length = max_returned_tokens
+         model.set_kv_cache(batch_size=1)
+
+     y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=int(top_k))
+     return tokenizer.decode(y)
+
+
+ @torch.inference_mode()
+ def generate(
+     model: GPT,
+     idx: torch.Tensor,
+     max_returned_tokens: int,
+     *,
+     temperature: float = 1.0,
+     top_k: Optional[int] = None,
+     eos_id: Optional[int] = None,
+ ) -> torch.Tensor:
+     """Takes a conditioning sequence (prompt) as input and generates as many new tokens as requested.
+
+     The implementation of this function is modified from A. Karpathy's nanoGPT.
+
+     Args:
+         model: The model to use.
+         idx: Tensor of shape (T) with indices of the prompt sequence.
+         max_returned_tokens: The maximum number of tokens to return (prompt plus generated).
+         temperature: Scales the predicted logits by 1 / temperature.
+         top_k: If specified, only sample among the tokens with the k highest probabilities.
+         eos_id: If specified, stop generating as soon as the <eos> token is emitted.
+     """
+     T = idx.size(0)
+     assert max_returned_tokens > T
+     if model.max_seq_length < max_returned_tokens - 1:
+         # Rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would
+         # introduce a data dependency on the `input_pos` tensor and impact model compilation. Since this
+         # setting is uncommon, we do not support it to avoid negatively impacting the overall speed.
+         raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
+
+     device, dtype = idx.device, idx.dtype
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
+     empty[:T] = idx
+     idx = empty
+     input_pos = torch.arange(0, T, device=device)
+
+     # generate up to a fixed number of tokens
+     for _ in range(max_returned_tokens - T):
+         x = idx.index_select(0, input_pos).view(1, -1)
+
+         # forward
+         logits = model(x, input_pos)
+         logits = logits[0, -1] / temperature
+
+         # optionally crop the logits to only the top k options
+         if top_k is not None:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
+
+         # advance the position index to the slot we are about to fill
+         input_pos = input_pos[-1:] + 1
+
+         # write the sampled token into the output buffer
+         idx = idx.index_copy(0, input_pos, idx_next)
+
+         # if the <eos> token was produced, stop early
+         if idx_next == eos_id:
+             return idx[:input_pos]  # include the EOS token
+
+     return idx
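The temperature and top-k handling at the core of generate() can be illustrated in isolation. A minimal sketch with made-up toy logits (the six values are illustrative, not model outputs):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0, -3.0])  # toy scores over a 6-token vocabulary
    temperature, top_k = 0.8, 3

    scaled = logits / temperature                         # temperature < 1 sharpens the distribution
    v, _ = torch.topk(scaled, top_k)                      # values of the 3 largest logits, descending
    scaled = torch.where(scaled < v[-1], -float("inf"), scaled)  # mask everything below the 3rd-best
    probs = torch.nn.functional.softmax(scaled, dim=-1)   # masked tokens get probability 0
    next_token = torch.multinomial(probs, num_samples=1)  # sample the next token id

Only the three highest-scoring tokens can ever be drawn here, and lowering the temperature further concentrates probability mass on the best one.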