# diffusion-gpt / app.py
# Hugging Face Space file; last commit a42b836 "Update app.py" (verified) by multimodalart; 16.8 kB.
import gradio as gr
import spaces
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import math
import os
import pickle
import requests
import textwrap
import subprocess
import shutil
import time
from dataclasses import dataclass
from typing import Optional
# --- 1. Automated Environment and Data Setup ---
def setup_environment():
    """
    Make sure the Shakespeare character dataset and its metadata exist locally.

    Each step is skipped when its output already exists (all paths are relative
    to the current working directory):

    - Clone https://github.com/karpathy/nanoGPT.git into ``nanoGPT``.
    - Copy ``nanoGPT/data/shakespeare_char`` to ``shakespeare_char``.
    - Run ``prepare.py`` inside ``shakespeare_char`` to create ``meta.pkl``
      (plus whatever other files that script writes).

    Returns immediately when ``shakespeare_char/meta.pkl`` is already present,
    so it is cheap to call on every start-up.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``prepare.py`` exits
            with a non-zero status (its stderr is printed first).
    """
    import sys  # only needed to locate the interpreter running this app

    nano_gpt_repo_path = 'nanoGPT'
    data_dir_path = 'shakespeare_char'
    meta_path = os.path.join(data_dir_path, 'meta.pkl')
    if os.path.exists(meta_path):
        print("Dataset and metadata found. Skipping setup.")
        return
    print("Required data not found. Starting one-time setup...")

    if not os.path.exists(nano_gpt_repo_path):
        print("Cloning nanoGPT repository...")
        try:
            subprocess.run(
                ['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'],
                check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e.stderr}")
            raise
        print("Cloned successfully.")
    else:
        print("nanoGPT repository already exists.")

    source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
    if not os.path.exists(data_dir_path):
        print(f"Copying '{source_data_dir}' to '{data_dir_path}'...")
        shutil.copytree(source_data_dir, data_dir_path)
        print("Copied successfully.")
    else:
        print(f"'{data_dir_path}' directory already exists.")

    prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
    if not os.path.exists(meta_path):
        print(f"Running data preparation script: '{prepare_script_path}'...")
        # Use the interpreter that is running this app instead of whatever
        # `python` resolves to on PATH. That name may be missing, or may point
        # at an environment without the packages prepare.py needs.
        python_executable = sys.executable or 'python'
        try:
            subprocess.run(
                [python_executable, 'prepare.py'],
                check=True, cwd=data_dir_path, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"Error running prepare.py: {e.stderr}")
            raise
        print("Data preparation script finished successfully.")
    print("Setup complete.")
# Run the one-time data setup at import time so that meta.pkl exists before it is read below.
setup_environment()
# --- 2. Global Setup & Helper Functions ---
data_dir = './shakespeare_char/'
meta_path = os.path.join(data_dir, 'meta.pkl')
# meta.pkl is produced locally by prepare.py during setup_environment().
# NOTE(review): pickle.load can execute arbitrary code; this is acceptable only because the
# file is generated locally rather than downloaded as-is.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']  # size of the token vocabulary (embedding/output width, sampling range)
itos = meta['itos']  # token id -> character mapping, used by decode()
stoi = meta['stoi']  # character -> token id mapping (not referenced elsewhere in this file)
context_length = 256  # length of generated sequences; also used as the model's block_size
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def decode(indices_tensor: torch.Tensor):
    """Turn a tensor of token ids into a string using the global ``itos`` table.

    For inputs with more than one dimension, dimension 0 is squeezed first. Only
    a size-1 leading dimension actually disappears. Ids that are missing from
    ``itos`` are rendered as ``'?'``.
    """
    flat = indices_tensor.squeeze(0) if indices_tensor.dim() > 1 else indices_tensor
    return ''.join(itos.get(token_id, '?') for token_id in flat.cpu().numpy())
def wrap_text(long_text, width=80):
    """Hard-wrap every line of ``long_text`` to at most ``width`` columns.

    Each line is wrapped independently with ``textwrap.fill``. Empty lines are
    kept as empty lines, so paragraph breaks survive.
    """
    out_lines = []
    for paragraph in long_text.splitlines():
        out_lines.append(textwrap.fill(paragraph, width=width) if paragraph else '')
    return "\n".join(out_lines)
# --- 3. Model Architecture (Identical to Notebook) ---
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT diffusion model (this app overrides every default)."""
    block_size: int = 1024  # maximum sequence length; size of the learned position table (wpe)
    vocab_size: int = 50304  # number of token ids; sizes the token embedding and the output layer
    n_layer: int = 12  # number of DDiTBlock layers
    n_head: int = 12  # attention heads per layer; must divide n_embd
    n_embd: int = 768  # hidden (residual stream) width
    cond_dim: int = 64  # width of the noise-level conditioning vector fed to the adaLN layers
    dropout: float = 0.0  # dropout probability (only active in training mode)
    bias: bool = False  # bias in c_fc/c_proj/c_attn Linears and LayerNorms (adaLN/output Linears always have one)
class MLP(nn.Module):
    """Position-wise feed-forward layer.

    Computes Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> dropout.
    The attribute names double as the checkpoint's state_dict keys.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class SelfAttention(nn.Module):
    """Multi-head, non-causal (bidirectional) self-attention.

    Uses ``F.scaled_dot_product_attention`` when this torch build provides it, and
    otherwise falls back to an explicit softmax(Q K^T / sqrt(head_dim)) V.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One fused projection yields queries, keys and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection back to the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(proj):
            # (batch, seq_len, width) -> (batch, n_head, seq_len, head_dim)
            return proj.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))
        if not self.flash:
            weights = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            weights = self.attn_dropout(F.softmax(weights, dim=-1))
            y = weights @ v
        else:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=False,
            )
        # Merge heads back: (batch, n_head, seq_len, head_dim) -> (batch, seq_len, width)
        merged = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Adaptive affine transform used by adaLN: ``x * (1 + scale) + shift`` (broadcasting)."""
    gain = 1 + scale
    return x * gain + shift
def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    """Return ``scale * (x + bias)``, plus ``residual`` if one is given.

    ``bias`` and ``residual`` may each be ``None``; that term is then skipped.
    """
    biased = x if bias is None else x + bias
    scaled = scale * biased
    return scaled if residual is None else residual + scaled
class DDiTBlock(nn.Module):
    """Transformer block conditioned on a vector ``c`` through adaptive LayerNorm (adaLN).

    ``adaLN_modulation`` maps ``c`` to six per-channel vectors: shift/scale/gate for
    the attention branch and shift/scale/gate for the MLP branch.
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
        self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
        # Zero-initialise the modulation so every shift/scale/gate starts at 0.
        # NOTE(review): GPT.__init__ later calls self.apply(self._init_weights), which
        # re-initialises every nn.Linear (this one included) with N(0, 0.02), so this
        # zeroing is overwritten when the block is built inside GPT.
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()
    def forward(self, x, c):
        # As produced by GPT.forward, c is (batch, cond_dim); [:, None] adds a sequence axis
        # so each modulation vector broadcasts over all positions.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x_skip = x
        x = modulate(self.ln_1(x), shift_msa, scale_msa)
        x = self.attn(x)
        # NOTE(review): attention is applied a second time here, to ln_1 of the first
        # attention's output (without modulation). Only that second result is gated and
        # added to x_skip. The usual DiT/DDiT formulation would gate the first output
        # directly, i.e. bias_add_scale(x, None, gate_msa, x_skip). This looks accidental,
        # but the pretrained checkpoint was presumably trained with this exact forward.
        # Confirm before "fixing" it, since a change would alter the loaded model's outputs.
        x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
        # MLP branch: modulated ln_2 -> MLP, gated and added residually.
        x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
        return x
class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a projection to vocab_size logits.

    Both the projection and the modulation layer are zeroed at construction.
    Note that GPT.__init__ re-initialises all nn.Linear layers afterwards.
    """

    def __init__(self, config):
        super().__init__()
        self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)
        self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
        for layer in (self.linear, self.adaLN_modulation):
            layer.weight.data.zero_()
            layer.bias.data.zero_()

    def forward(self, x, c):
        # One shift and one scale vector per batch item, broadcast over positions.
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        return self.linear(modulate(self.norm_final(x), shift, scale))
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps / noise levels into ``hidden_size``-dimensional vectors.

    A fixed sinusoidal feature vector of width ``frequency_embedding_size`` is fed
    through a two-layer MLP (Linear -> SiLU -> Linear).
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal features with ``dim`` columns for the values in ``t``.

        Expects a 1-D tensor ``t`` (GPT.forward passes ``sigma.reshape(-1)``).
        The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are
        sines. Their frequencies are ``exp(-log(max_period) * k / (dim // 2))``.
        An odd ``dim`` gets one trailing zero column.
        """
        n_freq = dim // 2
        exponents = -math.log(max_period) * torch.arange(start=0, end=n_freq, dtype=torch.float32) / n_freq
        freqs = torch.exp(exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2 == 1:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
class GPT(nn.Module):
    """Bidirectional transformer that predicts per-token log-scores given a noise level.

    Token and position embeddings pass through ``n_layer`` DDiTBlocks. Each block is
    conditioned on an embedding of ``sigma``. The blocks feed an adaLN-modulated
    output head over the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = DDitFinalLayer(config)
        # NOTE(review): this re-initialises *every* nn.Linear with N(0, 0.02). That includes
        # the adaLN_modulation and output layers that DDiTBlock / DDitFinalLayer zeroed.
        # It does not matter here (weights are loaded from a checkpoint) but does for training.
        self.apply(self._init_weights)
        # GPT-2 style scaled init for the residual output projections.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Initialise Linear weights ~ N(0, 0.02) with zero bias, and Embedding weights ~ N(0, 0.02)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, sigma):
        """Return log-scores for token ids ``idx`` at noise level ``sigma``.

        Args:
            idx: token ids of shape (batch, seq_len), with seq_len <= block_size.
            sigma: noise level(s), flattened to 1-D. Presumably one value per batch item.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size). The entry for each
            position's current token is forced to 0.
        """
        sigma = sigma.reshape(-1)
        _, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        c = F.silu(self.sigma_map(sigma))
        # Build positions on the input's device rather than the module-level default, so
        # the model works wherever it and its inputs actually live.
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x, c)
        x = self.transformer.ln_f(x)
        x = self.lm_head(x, c)
        # Zero the log-score of the token already present at each position.
        x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
        return x
class GeometricNoise:
    """Geometric noise schedule sigma(t) = sigma_min ** (1 - t) * sigma_max ** t.

    ``total_noise(t)`` returns sigma(t). ``rate_noise(t)`` returns its time derivative,
    sigma(t) * (log sigma_max - log sigma_min). Calling the object returns both
    as a ``(total, rate)`` tuple.
    """

    def __init__(self, sigma_min=1e-4, sigma_max=20):
        bounds = torch.tensor([sigma_min, sigma_max])
        # Kept on the module-level `device` so the schedule can be evaluated with t there.
        self.sigmas = 1.0 * bounds.to(device)

    def total_noise(self, t):
        low, high = self.sigmas[0], self.sigmas[1]
        return low ** (1 - t) * high ** t

    def rate_noise(self, t):
        low, high = self.sigmas[0], self.sigmas[1]
        return self.total_noise(t) * (high.log() - low.log())

    def __call__(self, t):
        return self.total_noise(t), self.rate_noise(t)
# --- 4. Inference & Sampling Logic (Identical to Notebook) ---
def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor:
    """Per-token transition probabilities for a uniform-noise step of size ``delta_sigma``.

    With ``V`` classes and ``p = (1 - exp(-delta_sigma)) / V``, each token moves to
    every other class with probability ``p``. It stays put with probability
    ``1 - (V - 1) * p``, so each row sums to 1.

    Args:
        x_t: integer token ids, any shape ``S``.
        delta_sigma: noise increment(s); ``delta_sigma[..., None]`` must broadcast
            against a tensor of shape ``(*S, V)``.
        num_classes: vocabulary size ``V``. Defaults to the module-level
            ``vocab_size``, which keeps existing two-argument calls working.

    Returns:
        Float tensor of shape ``(*S, V)`` on ``x_t``'s device.
    """
    if num_classes is None:
        num_classes = vocab_size
    base_prob = (1 - torch.exp(-delta_sigma[..., None])) / num_classes
    trans = torch.ones(*x_t.shape, num_classes, device=x_t.device) * base_prob
    # Clear the current-token entry, then give it all remaining probability mass.
    trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
    diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
    trans = trans.scatter(-1, x_t[..., None], diag_fill)
    return trans
def staggered_score(score, delta_sigma):
    """Adjust ``score`` for a uniform-noise step of size ``delta_sigma``.

    Returns ``score / e + ((e - 1) / (V * e)) * score.sum(-1)`` with
    ``e = exp(-delta_sigma)``, where ``V`` is the size of the last dimension.

    Args:
        score: non-negative scores whose last dimension is the vocabulary.
        delta_sigma: noise increment(s); ``exp(-delta_sigma)[..., None]`` must
            broadcast against ``score``.
    """
    # Take the vocabulary size from the scores themselves instead of a module global.
    # The model's output width equals vocab_size, so results are unchanged here.
    num_classes = score.size(-1)
    exp_factor = torch.exp(-delta_sigma)[..., None]
    correction = ((exp_factor - 1) / (num_classes * exp_factor)) * score.sum(dim=-1, keepdim=True)
    return correction + score / exp_factor
def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    """Draw one index along the last dimension of ``probs`` using the Gumbel-max trick.

    ``argmax(log p + g)`` with ``g ~ Gumbel(0, 1)`` samples an index with
    probability proportional to ``p``. A small epsilon keeps every logarithm finite.
    """
    tiny = 1e-10
    uniform = torch.rand_like(probs)
    gumbel = -torch.log(-torch.log(uniform + tiny) + tiny)
    logits = torch.log(probs + tiny)
    return (logits + gumbel).argmax(dim=-1)
# --- 5. Model Initialization and Loading ---
print("Initializing and loading the pretrained model...")
# Architecture hyperparameters. They must match the checkpoint below, because
# load_state_dict is called with its default (strict) key matching.
model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
                  bias=False, vocab_size=vocab_size, block_size=context_length, dropout=0.2)
config = GPTConfig(**model_args)
model = GPT(config)
# NOTE(review): the checkpoint is fetched from a remote URL and deserialised with torch.load,
# which is pickle-based unless weights_only is enabled. Consider passing weights_only=True if
# the installed torch version supports it for load_state_dict_from_url — confirm.
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
        map_location=device
    )
)
model.to(device)
# eval() disables the dropout=0.2 configured above.
model.eval()
# Noise schedule used by generate_text. The bounds presumably match training — confirm.
noise = GeometricNoise(sigma_min=1e-4, sigma_max=20)
print("Model loaded successfully.")
# --- 6. Gradio Interface Logic ---
def _denoise_step(x, curr_sigma_bar, delta_sigma):
    """Run one reverse-diffusion update on token ids ``x`` and return the resampled ids.

    The model's log-scores at ``curr_sigma_bar`` are exponentiated and corrected with
    ``staggered_score``. They are then multiplied by the ``transition`` probabilities
    for a step of size ``delta_sigma``, and a new token is sampled per position.
    """
    score = torch.exp(model(x, curr_sigma_bar))
    probs = staggered_score(score, delta_sigma) * transition(x, delta_sigma)
    return sample_categorical(probs)


def _format_frame(title, x):
    """Format the first sequence in ``x`` as a ``--- title ---`` header followed by wrapped text."""
    return f"--- {title} ---\n\n{wrap_text(decode(x[0]))}"


@spaces.GPU
def generate_text(steps):
    """
    Sample one sequence from the diffusion model, recording every intermediate state.

    Starts from uniformly random token ids of length ``context_length``. It then
    runs ``steps`` reverse-diffusion updates on the grid
    ``linspace(1, eps, steps + 1)``, followed by one final update that removes
    the remaining noise level at ``t = eps``.

    Args:
        steps: number of denoising steps. Converted with ``int``; values below 1
            are treated as 1 to avoid a division by zero.

    Returns:
        ``(final_text, diffusion_frames)``. The first element is the formatted
        final sample. The second is the list of formatted frames (initial noise,
        one per step, then the final text) for the UI to replay.
    """
    steps = max(1, int(steps))
    eps = 1e-5
    diffusion_frames = []

    x = torch.randint(0, vocab_size, (1, context_length), device=device)
    diffusion_frames.append(_format_frame("Initial Random Noise", x))

    timesteps = torch.linspace(1, eps, steps + 1, device=device)
    step_size = (1 - eps) / steps
    # The whole sampling loop, including the final step, is inference only.
    with torch.no_grad():
        for i in range(steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
            curr_sigma_bar = noise(t)[0]
            next_sigma_bar = noise(t - step_size)[0]
            x = _denoise_step(x, curr_sigma_bar, curr_sigma_bar - next_sigma_bar)
            diffusion_frames.append(_format_frame(f"Denoising Step {i + 1}/{steps}", x))

        # Final step: remove all remaining noise (delta_sigma = the current sigma_bar).
        t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
        curr_sigma_bar = noise(t)[0]
        x = _denoise_step(x, curr_sigma_bar, curr_sigma_bar)

    final_text = _format_frame(f"Final Denoised Text (Step {steps})", x)
    diffusion_frames.append(final_text)
    return final_text, diffusion_frames
def replay_diffusion(frames, replay_speed):
    """
    Slow replay phase. Yield the stored frames one by one, pausing between them.

    Args:
        frames: sequence of frame strings (``None`` is treated as empty).
        replay_speed: speed multiplier. The pause between frames is
            ``0.5 / replay_speed`` seconds. A non-positive value means no pause,
            instead of a ZeroDivisionError or a negative ``time.sleep``.

    Yields:
        Each frame in order. There is no pause after the last frame, so the event
        completes as soon as the final frame is shown.
    """
    delay = 0.5 / replay_speed if replay_speed > 0 else 0.0
    frames = list(frames or [])
    last_index = len(frames) - 1
    for index, frame in enumerate(frames):
        yield frame
        if index < last_index:
            time.sleep(delay)
# Define the Gradio UI
# Custom CSS: cap the app width, tighten heading/paragraph spacing, and render the
# textarea in a monospace font on a black background.
css = '''.gradio-container > .fillable {max-width: 720px !important}
h3{margin-top: 1em}
p{margin-top: 0}
textarea{font-family: monospace;background-color: black}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
    # Header: title, model credit and links.
    gr.Markdown(
        """
# The Annotated Discrete Diffusion Models
### Tiny 7.23M parameters Shakespeare character diffusion model by [Ashwani Kumar](https://x.com/ash_at_tt/status/1977376958859092250)
[GitHub](https://github.com/ash80/diffusion-gpt), [Colab](https://colab.research.google.com/github/ash80/diffusion-gpt/blob/master/The_Annotated_Discrete_Diffusion_Models.ipynb)
"""
    )
    generate_button = gr.Button("Generate", variant="primary")
    # Read-only output. It receives the final text from generate_text first, then the
    # replayed frames from replay_diffusion.
    output_textbox = gr.Textbox(
        label="Generated Text",
        lines=15,
        interactive=False,
        show_copy_button=True,
        placeholder="Generation will appear here..."
    )
    with gr.Row():
        steps_slider = gr.Slider(
            minimum=64,
            maximum=512,
            value=128,
            step=1,
            label="Denoising Steps",
            info="Number of steps in the generation process."
        )
        # Hidden (visible=False), so the replay always runs at the default value of 10.
        speed_slider = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Replay Speed",
            info="Controls the speed of the animation after generation.",
            visible=False
        )
    # Holds the frames returned by generate_text for the replay step.
    diffusion_frames_state = gr.State([])
    # Two-phase event. generate_text runs the whole diffusion and stores its frames;
    # .then() runs replay_diffusion, which streams those frames into the same textbox.
    # generate_event is not referenced anywhere else in this file.
    generate_event = generate_button.click(
        fn=generate_text,
        inputs=[steps_slider],
        outputs=[output_textbox, diffusion_frames_state]
    ).then(
        fn=replay_diffusion,
        inputs=[diffusion_frames_state, speed_slider],
        outputs=[output_textbox]
    )
# Launch the app only when this file is run directly.
if __name__ == "__main__":
    demo.launch()