Jyo-K committed on
Commit
3eeedde
·
verified ·
1 Parent(s): 960b06a

Upload 4 files

Browse files
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import math
import os

# Pick the best available accelerator: CUDA, then Apple Silicon (MPS), then CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Tokenizer setup
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab + 1 # +1 for mask token
# The mask token reuses the first id past the GPT-2 vocabulary.
mask_token_id = enc.n_vocab
15
+
16
def encode(s):
    """Tokenize the string *s* into GPT-2 BPE token ids."""
    return enc.encode(s)
18
+
19
def decode(l):
    """Detokenize *l* back into text, silently dropping any mask tokens."""
    visible = [t for t in l if t != mask_token_id]
    return enc.decode(visible)
21
+
22
def format_masked_text(l):
    """Render token list *l* as text, showing each mask token as " [MASK] "."""
    parts = []
    pending = []

    def flush():
        # Decode and emit any buffered non-mask tokens.
        if pending:
            parts.append(enc.decode(pending))
            pending.clear()

    for tok in l:
        if tok != mask_token_id:
            pending.append(tok)
        else:
            flush()
            parts.append(" [MASK] ")
    flush()
    return "".join(parts)
38
+
39
def norm(x):
    """RMS-normalize *x* over its last dimension (eps=1e-5, no learned scale)."""
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5)
    return x * inv_rms
41
+
42
def apply_rotary_emb(x, cos, sin):
    """Rotate *x* (shape (B, T, H, D)) by the given cos/sin rotary tables.

    The head dimension is split in half and the two halves are mixed as a
    2-D rotation; the result is cast back to x's dtype.
    """
    assert x.ndim == 4
    half = x.shape[-1] // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated = torch.cat([lo * cos + hi * sin, hi * cos - lo * sin], dim=-1)
    return rotated.to(x.dtype)
50
+
51
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and QK RMS-norm.

    Attention is bidirectional (non-causal) and all projections are bias-free.
    Attribute names (c_q/c_k/c_v/c_proj) match the checkpoint's state_dict keys.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x, cos_sin):
        batch, seq, _ = x.size()
        shape = (batch, seq, self.config.n_head, self.config.head_dim)
        query = self.c_q(x).view(shape)
        key = self.c_k(x).view(shape)
        value = self.c_v(x).view(shape)
        cos, sin = cos_sin
        # Rotary position encoding first, then RMS-normalize queries and keys.
        query = norm(apply_rotary_emb(query, cos, sin))
        key = norm(apply_rotary_emb(key, cos, sin))
        # (B, T, H, D) -> (B, H, T, D) as expected by scaled_dot_product_attention.
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=False)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.c_proj(merged)
73
+
74
class MLP(nn.Module):
    """SwiGLU feed-forward block: c_proj(silu(w1 x) * w2 x).

    The hidden width is 8/3 * n_embd, keeping the parameter count comparable
    to a conventional 4x GELU MLP. All projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_dim = int(8 * config.n_embd / 3)
        self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.c_proj(gate * value)
85
+
86
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin):
        x = x + self.attn(norm(x), cos_sin)
        return x + self.mlp(norm(x))
97
+
98
class Model(nn.Module):
    """Masked-diffusion transformer language model.

    Runs bidirectional (non-causal) attention over a partially masked sequence
    and predicts a token at every position. Uses rotary position embeddings,
    tied input/output embeddings, and optional conditioning on the mask rate
    ("diffusion time") via a small MLP added to the token embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(vocab_size, config.n_embd)
        # Embeds the scalar mask rate into the model dimension.
        self.time_emb = nn.Sequential(
            nn.Linear(1, config.n_embd),
            nn.SiLU(),
            nn.Linear(config.n_embd, config.n_embd),
        )
        # Rotary tables are precomputed for twice the block size as headroom.
        self.rotary_seq_len = config.block_size * 2
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len)
        # Non-persistent: the tables are recomputable and kept out of checkpoints.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight # tie weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights as N(0, 0.02); zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _precompute_rotary_embeddings(self, seq_len, base=10000, device=None):
        """Build cos/sin rotary tables of shape (1, seq_len, 1, head_dim // 2)."""
        if device is None:
            device = self.token_emb.weight.device
        channel_range = torch.arange(0, self.config.head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / self.config.head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        # Insert batch and head axes so the tables broadcast against (B, T, H, D/2).
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def forward(self, idx, targets=None, mask=None, mask_rate=None):
        """Run the model.

        Args:
            idx: (B, T) token ids; may contain the mask token id.
            targets: optional (B, T) target ids; when given, a loss is returned.
            mask: optional (B, T) weights; when given, the loss is averaged over
                masked positions only (sum of weighted losses / sum of weights).
            mask_rate: optional per-example mask rate for time conditioning.
                NOTE(review): the inline comment implies unsqueeze yields (B, 1, 1),
                which requires mask_rate to arrive as (B, 1) — confirm with callers
                (the generation path in this file never passes it).

        Returns:
            (logits, loss) where logits is (B, T, vocab_size) and loss is None
            when targets is None.
        """
        B, T = idx.size()
        x = self.token_emb(idx)
        if mask_rate is not None:
            t = mask_rate.float().unsqueeze(-1) # (B, 1, 1)
            x = x + self.time_emb(t)
        x = norm(x)
        # Slice the precomputed rotary tables to the current sequence length.
        cos_sin = (self.cos[:, :T], self.sin[:, :T])
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits_flat = logits.view(B * T, C)
            targets_flat = targets.view(B * T)
            if mask is not None:
                # Per-token loss, averaged over masked positions only.
                mask_flat = mask.view(B * T)
                loss = F.cross_entropy(logits_flat, targets_flat, reduction="none")
                loss = (loss * mask_flat).sum() / mask_flat.sum()
            else:
                loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss
162
+
163
class Config:
    """Hyperparameters for the two supported model sizes.

    Raises ValueError for any model_type other than 'medium' or 'gpt2'.
    """

    # model_type -> (n_embd, n_head, n_layer, checkpoint filename)
    _PRESETS = {
        'medium': (512, 8, 8, "tinystories_diffusion_med_dual.pt"),
        'gpt2': (768, 12, 12, "tinystories_diffusion_GPT2_dual.pt"),
    }

    def __init__(self, model_type):
        self.block_size = 512
        if model_type not in self._PRESETS:
            raise ValueError("model_type must be 'medium' or 'gpt2'")
        self.n_embd, self.n_head, self.n_layer, self.weights_path = self._PRESETS[model_type]
        self.head_dim = self.n_embd // self.n_head
179
+
180
# Lazily-loaded model cache: at most one model is held in memory at a time.
loaded_model_type = None
loaded_model = None

def get_model(model_type):
    """Return (model, config) for *model_type*, reusing the cached model when possible.

    Loads weights from config.weights_path if the file exists; otherwise warns
    and keeps the randomly initialized parameters.
    """
    global loaded_model_type, loaded_model
    if loaded_model_type == model_type and loaded_model is not None:
        return loaded_model, Config(model_type)

    print(f"Loading {model_type} model...")
    config = Config(model_type)
    model = Model(config)
    weights_path = config.weights_path

    if os.path.exists(weights_path):
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        # Strip the 'module.' prefix that DataParallel checkpoints carry.
        cleaned = {
            (key[7:] if key.startswith("module.") else key): tensor
            for key, tensor in state_dict.items()
        }
        model.load_state_dict(cleaned)
        print("Model loaded successfully!")
    else:
        print(f"Warning: {weights_path} not found. Running with uninitialized random parameters.")

    model.to(device)
    loaded_model = model
    loaded_model_type = model_type
    return model, config
212
+
213
@torch.no_grad()
def generate_diffusion(prompt, max_new_tokens=100, mode="Direct Output", model_type="medium"):
    """Generate text by iterative parallel demasking (diffusion-style decoding).

    Works block by block: fills a window with mask tokens after the context,
    then repeatedly asks the model for predictions and commits tokens at the
    positions where it is confident, until the whole block is decoded.

    Yields intermediate masked renderings when mode == "Show Generation Process",
    and always yields the final decoded string last.

    Args:
        prompt: seed text to continue.
        max_new_tokens: number of tokens to generate after the prompt.
        mode: "Direct Output" or "Show Generation Process".
        model_type: 'medium' or 'gpt2' (see Config).
    """
    model, config = get_model(model_type)
    prompt_tokens = encode(prompt)
    model.eval()
    prompt_len = len(prompt_tokens)
    all_tokens = prompt_tokens.copy()
    temp = 1.0  # softmax temperature
    confidence_threshold = 0.95  # a position is decoded once its top-k mass reaches this
    top_k = 3  # sample among the 3 most likely tokens per position

    # Outer loop: one pass per generated block until max_new_tokens are produced.
    while len(all_tokens) - len(prompt_tokens) < max_new_tokens:
        curr_prompt_len = len(all_tokens)
        # Room left in the model window vs. tokens still requested.
        block_len = min(config.block_size - curr_prompt_len, len(prompt_tokens) + max_new_tokens - len(all_tokens))
        if block_len <= 0: break

        # Start from a fully masked window and copy in everything produced so far.
        # NOTE(review): assumes curr_prompt_len <= config.block_size; a prompt
        # longer than block_size would make this assignment fail — confirm the
        # UI bounds the prompt length.
        x = torch.full((1, config.block_size), mask_token_id, dtype=torch.long, device=device)
        x[0, :curr_prompt_len] = torch.tensor(all_tokens[-curr_prompt_len:], device=device)

        # True at positions that still need to be decoded in this block.
        masked = torch.zeros(1, config.block_size, dtype=torch.bool, device=device)
        masked[0, curr_prompt_len : curr_prompt_len + block_len] = True

        # Inner loop: decode the confident positions, re-predict, repeat.
        while masked.any():
            logits, _ = model(x)
            probs = F.softmax(logits / temp, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            # Confidence = total probability mass captured by the top-k candidates.
            confidences = top_k_probs.sum(dim=-1)

            decode_mask = (confidences >= confidence_threshold) & masked
            if not decode_mask.any():
                # Nothing clears the threshold: force-decode the single most
                # confident masked position so the loop always makes progress.
                masked_confidences = torch.where(masked, confidences, torch.tensor(-float('inf')).to(device))
                decode_mask.view(-1)[masked_confidences.argmax()] = True

            # Sample one of the top-k tokens at every position (renormalized),
            # then keep the samples only where decode_mask is set.
            top_k_probs_norm = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
            sampled_k = torch.multinomial(top_k_probs_norm.view(-1, top_k), 1).view(1, config.block_size)
            sampled_tokens = torch.gather(top_k_indices, -1, sampled_k.unsqueeze(-1)).squeeze(-1)

            x = torch.where(decode_mask, sampled_tokens, x)
            masked = masked & ~decode_mask

            if mode == "Show Generation Process":
                # Stream the partially decoded block with [MASK] placeholders.
                current_block = x[0, curr_prompt_len : curr_prompt_len + block_len].tolist()
                yield format_masked_text(all_tokens + current_block)

        all_tokens.extend(x[0, curr_prompt_len : curr_prompt_len + block_len].tolist())

    full_output = decode(all_tokens)
    yield full_output
261
+
262
def gradio_fn(prompt, display_mode, max_tokens, model_type):
    """Gradio streaming callback: relay every partial/final output to the UI."""
    yield from generate_diffusion(prompt, max_new_tokens=max_tokens, mode=display_mode, model_type=model_type)
265
+
266
# Gradio UI: prompt and controls on the left, streaming output on the right.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# TinyStories Diffusion LM")
    gr.Markdown("A non-autoregressive language model leveraging parallel block-decoding and SwiGLU networks.")

    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(lines=2, placeholder="Once upon a time, there was a little girl who", label="Prompt (approx 10 words)")

            # Model size, display behaviour, and generation length controls.
            model_type_in = gr.Radio(["medium", "gpt2"], value="medium", label="Model Architecture")
            mode = gr.Radio(["Direct Output", "Show Generation Process"], value="Direct Output", label="Display Mode")
            max_tokens = gr.Slider(minimum=20, maximum=1000, value=100, step=1, label="Max Tokens")

            generate_btn = gr.Button("Generate Story", variant='primary')

        with gr.Column():
            output = gr.Textbox(lines=10, label="Output")

    # gradio_fn is a generator, so each yielded string updates `output` in place.
    generate_btn.click(fn=gradio_fn, inputs=[prompt_in, mode, max_tokens, model_type_in], outputs=output)

if __name__ == "__main__":
    # queue() enables Gradio's streaming of generator outputs to the client.
    demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.0.0
2
+ gradio>=4.0.0
3
+ tiktoken>=0.6.0
tinystories_diffusion_GPT2_dual.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fcd025ac0e4b19e5d9d40c2d2d4f7ff7410b9f222090e86b292f9a04d72eb77
3
+ size 496536064
tinystories_diffusion_med_dual.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee541e39aa08270b69c86eed5847cda6d5bb447059ca4c0c60f913e9dd9a6101
3
+ size 204655089