PraneshJs commited on
Commit
6d97cab
·
verified ·
1 Parent(s): 12dfcb4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IMAGE DIFFUSION VISUALIZER — ADVANCED
2
+ # Visualizes how a (tiny) Stable Diffusion model denoises step by step.
3
+ # Model: hf-internal-testing/tiny-stable-diffusion-pipe (small, CPU-safe, for demos)
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import numpy as np
8
+ from diffusers import DiffusionPipeline
9
+ from sklearn.decomposition import PCA
10
+ import plotly.graph_objects as go
11
+ import plotly.express as px
12
+ from PIL import Image
13
+ import time
14
+
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-pipe"
17
+
18
+ PIPE_CACHE = None
19
+
20
+
21
+ # -------------------- MODEL LOADING -------------------- #
22
+
23
+ def get_pipe():
24
+ """Lazy-load and cache the tiny Stable Diffusion pipeline."""
25
+ global PIPE_CACHE
26
+ if PIPE_CACHE is not None:
27
+ return PIPE_CACHE
28
+ pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
29
+ pipe.to(DEVICE)
30
+ pipe.safety_checker = None # tiny pipe usually doesn't have NSFW issues; keep simple
31
+ PIPE_CACHE = pipe
32
+ return PIPE_CACHE
33
+
34
+
35
+ # -------------------- CORE UTILS -------------------- #
36
+
37
+ def decode_latent_to_pil(pipe, latent_np):
38
+ """
39
+ Decode a latent (C,H,W) numpy array to a PIL image using the VAE.
40
+ Works for intermediate steps too.
41
+ """
42
+ vae = pipe.vae
43
+ latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
44
+ # scaling_factor is used in SD-style VAEs; fallback to standard SD value
45
+ scale = getattr(vae.config, "scaling_factor", 0.18215)
46
+ with torch.no_grad():
47
+ image = vae.decode(latent / scale).sample
48
+ image = (image / 2 + 0.5).clamp(0, 1)
49
+ image = image[0].permute(1, 2, 0).cpu().numpy()
50
+ image = (image * 255).astype("uint8")
51
+ return Image.fromarray(image)
52
+
53
+
54
+ def compute_pca_over_steps(latents_list):
55
+ """
56
+ latents_list: list of (C,H,W) numpy arrays.
57
+ Flatten each into a single vector; run PCA across steps.
58
+ Returns (S,2) array of 2D coords.
59
+ """
60
+ if len(latents_list) == 0:
61
+ return None
62
+ flat = [x.reshape(-1) for x in latents_list]
63
+ mat = np.stack(flat, axis=0) # (steps, dim)
64
+ if mat.shape[0] < 2 or mat.shape[1] < 2:
65
+ # Not enough data for PCA; return zeros
66
+ return np.zeros((mat.shape[0], 2))
67
+ try:
68
+ pca = PCA(n_components=2)
69
+ pts = pca.fit_transform(mat)
70
+ return pts
71
+ except Exception:
72
+ return np.zeros((mat.shape[0], 2))
73
+
74
+
75
+ def compute_norms_over_steps(latents_list):
76
+ """Compute L2 norm of each latent across channels & spatial dims."""
77
+ if len(latents_list) == 0:
78
+ return []
79
+ flat = [x.reshape(-1) for x in latents_list]
80
+ norms = [float(np.linalg.norm(v)) for v in flat]
81
+ return norms
82
+
83
+
84
+ def explain(simple=True):
85
+ if simple:
86
+ return (
87
+ "🧒 **Simple explanation of what you see:**\n\n"
88
+ "1. The model starts with a totally noisy image.\n"
89
+ "2. Step by step, it removes noise and shapes the picture.\n"
90
+ "3. Your words (the prompt) tell it *what* to draw.\n"
91
+ "4. The slider lets you move through these steps:\n"
92
+ " - Early steps = mostly noise\n"
93
+ " - Later steps = clearer image\n"
94
+ )
95
+ else:
96
+ return (
97
+ "🔬 **Technical explanation:**\n\n"
98
+ "- We use a tiny Stable Diffusion-style pipeline.\n"
99
+ "- At each timestep `t`, the UNet predicts noise εₜ for latent `zₜ`.\n"
100
+ "- The scheduler updates `zₜ → zₜ₋₁` using εₜ.\n"
101
+ "- We record the latent after each step and decode it with the VAE.\n"
102
+ "- PCA over flattened latents shows the trajectory in latent space.\n"
103
+ "- Latent norm vs step shows how the magnitude evolves during denoising.\n"
104
+ )
105
+
106
+
107
+ def make_pca_figure(points, current_idx):
108
+ """Make a PCA trajectory plot over steps, highlighting the selected step."""
109
+ steps = list(range(len(points)))
110
+ fig = go.Figure()
111
+ fig.add_trace(go.Scatter(
112
+ x=points[:, 0],
113
+ y=points[:, 1],
114
+ mode="lines+markers",
115
+ name="Steps",
116
+ text=[f"step {i}" for i in steps]
117
+ ))
118
+ if 0 <= current_idx < len(points):
119
+ fig.add_trace(go.Scatter(
120
+ x=[points[current_idx, 0]],
121
+ y=[points[current_idx, 1]],
122
+ mode="markers+text",
123
+ text=[f"step {current_idx}"],
124
+ textposition="top center",
125
+ marker=dict(size=14, color="red"),
126
+ name="Current step"
127
+ ))
128
+ fig.update_layout(
129
+ title="Latent PCA trajectory over steps",
130
+ xaxis_title="PC1",
131
+ yaxis_title="PC2",
132
+ height=400
133
+ )
134
+ return fig
135
+
136
+
137
+ def make_norm_figure(norms, current_idx):
138
+ """Plot latent norm vs step, highlighting the current step."""
139
+ steps = list(range(len(norms)))
140
+ fig = go.Figure()
141
+ fig.add_trace(go.Scatter(
142
+ x=steps,
143
+ y=norms,
144
+ mode="lines+markers",
145
+ name="Latent norm"
146
+ ))
147
+ if 0 <= current_idx < len(norms):
148
+ fig.add_trace(go.Scatter(
149
+ x=[steps[current_idx]],
150
+ y=[norms[current_idx]],
151
+ mode="markers",
152
+ marker=dict(size=14, color="red"),
153
+ name="Current step"
154
+ ))
155
+ fig.update_layout(
156
+ title="Latent L2 norm vs diffusion step",
157
+ xaxis_title="Step index (0 = most noisy)",
158
+ yaxis_title="‖latent‖₂",
159
+ height=400
160
+ )
161
+ return fig
162
+
163
+
164
+ # -------------------- MAIN ANALYSIS FUNCTION -------------------- #
165
+
166
+ def run_diffusion_analysis(prompt, num_steps, guidance, seed, simple_mode):
167
+ """
168
+ Run the tiny diffusion pipeline, recording latents at each step.
169
+ Returns Gradio updates + a state dict.
170
+ """
171
+ if not prompt or not prompt.strip():
172
+ return (
173
+ None, # final image
174
+ f"⚠️ Please enter a prompt.",
175
+ gr.update(maximum=0, value=0),
176
+ None, None, None,
177
+ {
178
+ "error": "no_prompt"
179
+ }
180
+ )
181
+
182
+ pipe = get_pipe()
183
+ num_steps = int(num_steps)
184
+ guidance = float(guidance)
185
+
186
+ # Seed handling
187
+ if seed is None or seed < 0:
188
+ generator = torch.Generator(device=DEVICE)
189
+ else:
190
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
191
+
192
+ latents_list = []
193
+ timesteps_list = []
194
+
195
+ def callback(step, timestep, latents):
196
+ # latents: (batch, C, H, W)
197
+ latents_list.append(latents.detach().cpu().numpy()[0])
198
+ timesteps_list.append(int(timestep))
199
+
200
+ t0 = time.time()
201
+ try:
202
+ result = pipe(
203
+ prompt,
204
+ num_inference_steps=num_steps,
205
+ guidance_scale=guidance,
206
+ generator=generator,
207
+ callback=callback,
208
+ callback_steps=1,
209
+ )
210
+ except Exception as e:
211
+ return (
212
+ None,
213
+ f"❌ Model / diffusion error: {e}",
214
+ gr.update(maximum=0, value=0),
215
+ None, None, None,
216
+ {
217
+ "error": "diffusion_error",
218
+ "details": str(e)
219
+ }
220
+ )
221
+
222
+ elapsed = time.time() - t0
223
+
224
+ if len(latents_list) == 0:
225
+ return (
226
+ None,
227
+ "❌ No latents were collected. Something went wrong inside the pipeline.",
228
+ gr.update(maximum=0, value=0),
229
+ None, None, None,
230
+ {
231
+ "error": "no_latents"
232
+ }
233
+ )
234
+
235
+ final_image = result.images[0] # PIL
236
+
237
+ # Compute PCA and norms over steps
238
+ pca_points = compute_pca_over_steps(latents_list)
239
+ norms = compute_norms_over_steps(latents_list)
240
+
241
+ # Default step: last (most denoised)
242
+ current_idx = len(latents_list) - 1
243
+
244
+ # Decode image for current step
245
+ try:
246
+ step_image = decode_latent_to_pil(pipe, latents_list[current_idx])
247
+ except Exception:
248
+ step_image = None
249
+
250
+ # Build plots
251
+ pca_fig = make_pca_figure(pca_points, current_idx) if pca_points is not None else None
252
+ norm_fig = make_norm_figure(norms, current_idx) if norms else None
253
+
254
+ # Explanation
255
+ explanation = explain(simple_mode)
256
+ explanation += f"\n\n⏱ **Runtime:** {elapsed:.2f}s • **Steps:** {len(latents_list)}"
257
+
258
+ # State dict to keep everything for slider updates
259
+ state = {
260
+ "prompt": prompt,
261
+ "num_steps": num_steps,
262
+ "guidance": guidance,
263
+ "seed": seed,
264
+ "latents": latents_list,
265
+ "timesteps": timesteps_list,
266
+ "pca_points": pca_points,
267
+ "norms": norms
268
+ }
269
+
270
+ step_slider_update = gr.update(maximum=len(latents_list)-1, value=current_idx)
271
+
272
+ return (
273
+ final_image,
274
+ explanation,
275
+ step_slider_update,
276
+ step_image,
277
+ pca_fig,
278
+ norm_fig,
279
+ state
280
+ )
281
+
282
+
283
+ def update_step_view(state, step_idx):
284
+ """
285
+ When the user moves the step slider, update:
286
+ - the decoded image at that step
287
+ - the PCA plot (highlight current)
288
+ - the norm plot (highlight current)
289
+ """
290
+ if not state or "latents" not in state:
291
+ return gr.update(value=None), gr.update(value=None), gr.update(value=None)
292
+
293
+ latents_list = state["latents"]
294
+ pca_points = state["pca_points"]
295
+ norms = state["norms"]
296
+
297
+ if len(latents_list) == 0:
298
+ return gr.update(value=None), gr.update(value=None), gr.update(value=None)
299
+
300
+ step_idx = int(step_idx)
301
+ step_idx = max(0, min(step_idx, len(latents_list) - 1))
302
+
303
+ pipe = get_pipe()
304
+
305
+ # Decode image at this step
306
+ try:
307
+ step_image = decode_latent_to_pil(pipe, latents_list[step_idx])
308
+ except Exception:
309
+ step_image = None
310
+
311
+ # Update PCA & norm plots
312
+ pca_fig = make_pca_figure(pca_points, step_idx) if pca_points is not None else None
313
+ norm_fig = make_norm_figure(norms, step_idx) if norms else None
314
+
315
+ return gr.update(value=step_image), gr.update(value=pca_fig), gr.update(value=norm_fig)
316
+
317
+
318
+ # -------------------- GRADIO UI -------------------- #
319
+
320
+ with gr.Blocks(title="Diffusion Visualizer — Noise to Image", theme=gr.themes.Soft()) as demo:
321
+
322
+ gr.Markdown("# 🧠 Image Diffusion Visualizer (Advanced)")
323
+ gr.Markdown(
324
+ "See how a tiny Stable Diffusion model turns **pure noise** into an image "
325
+ "step by step. Use the slider to move through the diffusion process."
326
+ )
327
+
328
+ with gr.Row():
329
+ with gr.Column(scale=2):
330
+ prompt_box = gr.Textbox(
331
+ label="Prompt",
332
+ value="a small house in the forest, digital art",
333
+ lines=3
334
+ )
335
+ num_steps_slider = gr.Slider(
336
+ minimum=5, maximum=50, value=20, step=1,
337
+ label="Number of diffusion steps"
338
+ )
339
+ guidance_slider = gr.Slider(
340
+ minimum=1.0, maximum=10.0, value=7.5, step=0.5,
341
+ label="Guidance scale (higher = follow prompt more)"
342
+ )
343
+ seed_box = gr.Number(
344
+ label="Seed (leave -1 for random)",
345
+ value=-1,
346
+ precision=0
347
+ )
348
+ simple_mode_chk = gr.Checkbox(
349
+ label="Explain in simple terms (for kids/elders)",
350
+ value=True
351
+ )
352
+ run_btn = gr.Button("Generate & Analyze", variant="primary")
353
+
354
+ with gr.Column(scale=2):
355
+ final_image = gr.Image(label="Final generated image")
356
+ explanation_md = gr.Markdown(label="Explanation")
357
+
358
+ gr.Markdown("### 🔍 Explore the denoising process")
359
+ step_slider = gr.Slider(
360
+ minimum=0, maximum=0, value=0, step=1,
361
+ label="View step (0 = early, noisy • max = late, clear)"
362
+ )
363
+
364
+ with gr.Row():
365
+ with gr.Column():
366
+ step_image = gr.Image(label="Image at this diffusion step")
367
+ with gr.Column():
368
+ pca_plot = gr.Plot(label="Latent PCA trajectory")
369
+ with gr.Column():
370
+ norm_plot = gr.Plot(label="Latent norm vs step")
371
+
372
+ state = gr.State()
373
+
374
+ # Wire run button
375
+ run_btn.click(
376
+ run_diffusion_analysis,
377
+ inputs=[prompt_box, num_steps_slider, guidance_slider, seed_box, simple_mode_chk],
378
+ outputs=[final_image, explanation_md, step_slider, step_image, pca_plot, norm_plot, state]
379
+ )
380
+
381
+ # Wire slider change
382
+ step_slider.change(
383
+ update_step_view,
384
+ inputs=[state, step_slider],
385
+ outputs=[step_image, pca_plot, norm_plot]
386
+ )
387
+
388
+ demo.launch()