PraneshJs commited on
Commit
4a4fecc
·
verified ·
1 Parent(s): 181be6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -313
app.py CHANGED
@@ -1,388 +1,229 @@
1
- # IMAGE DIFFUSION VISUALIZER — ADVANCED
2
- # Visualizes how a (tiny) Stable Diffusion model denoises step by step.
3
- # Model: hf-internal-testing/tiny-stable-diffusion-pipe (small, CPU-safe, for demos)
 
 
4
 
5
  import gradio as gr
6
  import torch
7
  import numpy as np
8
- from diffusers import DiffusionPipeline
9
  from sklearn.decomposition import PCA
10
  import plotly.graph_objects as go
11
- import plotly.express as px
12
  from PIL import Image
13
  import time
 
14
 
15
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
- MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-pipe"
 
 
 
 
 
 
 
 
17
 
18
  PIPE_CACHE = None
19
 
20
 
21
- # -------------------- MODEL LOADING -------------------- #
22
 
23
  def get_pipe():
24
- """Lazy-load and cache the tiny Stable Diffusion pipeline."""
25
  global PIPE_CACHE
26
- if PIPE_CACHE is not None:
27
  return PIPE_CACHE
28
- pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
 
29
  pipe.to(DEVICE)
30
- pipe.safety_checker = None # tiny pipe usually doesn't have NSFW issues; keep simple
 
 
 
 
 
 
31
  PIPE_CACHE = pipe
32
  return PIPE_CACHE
33
 
34
 
35
- # -------------------- CORE UTILS -------------------- #
36
 
37
- def decode_latent_to_pil(pipe, latent_np):
38
- """
39
- Decode a latent (C,H,W) numpy array to a PIL image using the VAE.
40
- Works for intermediate steps too.
41
- """
42
- vae = pipe.vae
43
- latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
44
- # scaling_factor is used in SD-style VAEs; fallback to standard SD value
45
- scale = getattr(vae.config, "scaling_factor", 0.18215)
46
- with torch.no_grad():
47
- image = vae.decode(latent / scale).sample
48
- image = (image / 2 + 0.5).clamp(0, 1)
49
- image = image[0].permute(1, 2, 0).cpu().numpy()
50
- image = (image * 255).astype("uint8")
51
- return Image.fromarray(image)
52
-
53
-
54
- def compute_pca_over_steps(latents_list):
55
- """
56
- latents_list: list of (C,H,W) numpy arrays.
57
- Flatten each into a single vector; run PCA across steps.
58
- Returns (S,2) array of 2D coords.
59
- """
60
- if len(latents_list) == 0:
61
- return None
62
- flat = [x.reshape(-1) for x in latents_list]
63
- mat = np.stack(flat, axis=0) # (steps, dim)
64
- if mat.shape[0] < 2 or mat.shape[1] < 2:
65
- # Not enough data for PCA; return zeros
66
- return np.zeros((mat.shape[0], 2))
67
  try:
68
  pca = PCA(n_components=2)
69
- pts = pca.fit_transform(mat)
70
  return pts
71
- except Exception:
72
- return np.zeros((mat.shape[0], 2))
73
-
74
-
75
- def compute_norms_over_steps(latents_list):
76
- """Compute L2 norm of each latent across channels & spatial dims."""
77
- if len(latents_list) == 0:
78
- return []
79
- flat = [x.reshape(-1) for x in latents_list]
80
- norms = [float(np.linalg.norm(v)) for v in flat]
81
- return norms
82
-
83
-
84
- def explain(simple=True):
85
- if simple:
86
- return (
87
- "🧒 **Simple explanation of what you see:**\n\n"
88
- "1. The model starts with a totally noisy image.\n"
89
- "2. Step by step, it removes noise and shapes the picture.\n"
90
- "3. Your words (the prompt) tell it *what* to draw.\n"
91
- "4. The slider lets you move through these steps:\n"
92
- " - Early steps = mostly noise\n"
93
- " - Later steps = clearer image\n"
94
- )
95
- else:
96
- return (
97
- "🔬 **Technical explanation:**\n\n"
98
- "- We use a tiny Stable Diffusion-style pipeline.\n"
99
- "- At each timestep `t`, the UNet predicts noise εₜ for latent `zₜ`.\n"
100
- "- The scheduler updates `zₜ → zₜ₋₁` using εₜ.\n"
101
- "- We record the latent after each step and decode it with the VAE.\n"
102
- "- PCA over flattened latents shows the trajectory in latent space.\n"
103
- "- Latent norm vs step shows how the magnitude evolves during denoising.\n"
104
- )
105
-
106
-
107
- def make_pca_figure(points, current_idx):
108
- """Make a PCA trajectory plot over steps, highlighting the selected step."""
109
- steps = list(range(len(points)))
110
- fig = go.Figure()
111
- fig.add_trace(go.Scatter(
112
- x=points[:, 0],
113
- y=points[:, 1],
114
- mode="lines+markers",
115
- name="Steps",
116
- text=[f"step {i}" for i in steps]
117
- ))
118
- if 0 <= current_idx < len(points):
119
- fig.add_trace(go.Scatter(
120
- x=[points[current_idx, 0]],
121
- y=[points[current_idx, 1]],
122
- mode="markers+text",
123
- text=[f"step {current_idx}"],
124
- textposition="top center",
125
- marker=dict(size=14, color="red"),
126
- name="Current step"
127
- ))
128
- fig.update_layout(
129
- title="Latent PCA trajectory over steps",
130
- xaxis_title="PC1",
131
- yaxis_title="PC2",
132
- height=400
133
- )
134
- return fig
135
 
136
 
137
- def make_norm_figure(norms, current_idx):
138
- """Plot latent norm vs step, highlighting the current step."""
139
- steps = list(range(len(norms)))
140
- fig = go.Figure()
141
- fig.add_trace(go.Scatter(
142
- x=steps,
143
- y=norms,
144
- mode="lines+markers",
145
- name="Latent norm"
146
- ))
147
- if 0 <= current_idx < len(norms):
148
- fig.add_trace(go.Scatter(
149
- x=[steps[current_idx]],
150
- y=[norms[current_idx]],
151
- mode="markers",
152
- marker=dict(size=14, color="red"),
153
- name="Current step"
154
- ))
155
- fig.update_layout(
156
- title="Latent L2 norm vs diffusion step",
157
- xaxis_title="Step index (0 = most noisy)",
158
- yaxis_title="‖latent‖₂",
159
- height=400
160
- )
161
- return fig
162
 
 
163
 
164
- # -------------------- MAIN ANALYSIS FUNCTION -------------------- #
165
-
166
- def run_diffusion_analysis(prompt, num_steps, guidance, seed, simple_mode):
167
- """
168
- Run the tiny diffusion pipeline, recording latents at each step.
169
- Returns Gradio updates + a state dict.
170
- """
171
- if not prompt or not prompt.strip():
172
- return (
173
- None, # final image
174
- f"⚠️ Please enter a prompt.",
175
- gr.update(maximum=0, value=0),
176
- None, None, None,
177
- {
178
- "error": "no_prompt"
179
- }
180
- )
181
 
182
  pipe = get_pipe()
183
- num_steps = int(num_steps)
184
- guidance = float(guidance)
185
 
186
- # Seed handling
187
- if seed is None or seed < 0:
188
- generator = torch.Generator(device=DEVICE)
189
- else:
190
- generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
191
 
192
  latents_list = []
193
- timesteps_list = []
194
 
195
- def callback(step, timestep, latents):
196
- # latents: (batch, C, H, W)
197
  latents_list.append(latents.detach().cpu().numpy()[0])
198
- timesteps_list.append(int(timestep))
199
 
200
  t0 = time.time()
201
- try:
202
- result = pipe(
203
- prompt,
204
- num_inference_steps=num_steps,
205
- guidance_scale=guidance,
206
- generator=generator,
207
- callback=callback,
208
- callback_steps=1,
209
- )
210
- except Exception as e:
211
- return (
212
- None,
213
- f"❌ Model / diffusion error: {e}",
214
- gr.update(maximum=0, value=0),
215
- None, None, None,
216
- {
217
- "error": "diffusion_error",
218
- "details": str(e)
219
- }
220
- )
221
-
222
- elapsed = time.time() - t0
223
-
224
- if len(latents_list) == 0:
225
- return (
226
- None,
227
- "❌ No latents were collected. Something went wrong inside the pipeline.",
228
- gr.update(maximum=0, value=0),
229
- None, None, None,
230
- {
231
- "error": "no_latents"
232
- }
233
- )
234
-
235
- final_image = result.images[0] # PIL
236
-
237
- # Compute PCA and norms over steps
238
- pca_points = compute_pca_over_steps(latents_list)
239
- norms = compute_norms_over_steps(latents_list)
240
-
241
- # Default step: last (most denoised)
242
- current_idx = len(latents_list) - 1
243
-
244
- # Decode image for current step
245
- try:
246
- step_image = decode_latent_to_pil(pipe, latents_list[current_idx])
247
- except Exception:
248
- step_image = None
249
 
250
- # Build plots
251
- pca_fig = make_pca_figure(pca_points, current_idx) if pca_points is not None else None
252
- norm_fig = make_norm_figure(norms, current_idx) if norms else None
 
 
 
 
 
 
 
 
 
253
 
254
- # Explanation
255
- explanation = explain(simple_mode)
256
- explanation += f"\n\n⏱ **Runtime:** {elapsed:.2f}s • **Steps:** {len(latents_list)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
- # State dict to keep everything for slider updates
259
  state = {
260
- "prompt": prompt,
261
- "num_steps": num_steps,
262
- "guidance": guidance,
263
- "seed": seed,
264
  "latents": latents_list,
265
- "timesteps": timesteps_list,
266
- "pca_points": pca_points,
267
  "norms": norms
268
  }
269
 
270
- step_slider_update = gr.update(maximum=len(latents_list)-1, value=current_idx)
271
-
272
  return (
273
- final_image,
274
  explanation,
275
- step_slider_update,
276
  step_image,
277
- pca_fig,
278
- norm_fig,
279
  state
280
  )
281
 
282
 
283
- def update_step_view(state, step_idx):
284
- """
285
- When the user moves the step slider, update:
286
- - the decoded image at that step
287
- - the PCA plot (highlight current)
288
- - the norm plot (highlight current)
289
- """
290
- if not state or "latents" not in state:
291
- return gr.update(value=None), gr.update(value=None), gr.update(value=None)
292
-
293
- latents_list = state["latents"]
294
- pca_points = state["pca_points"]
295
- norms = state["norms"]
296
-
297
- if len(latents_list) == 0:
298
- return gr.update(value=None), gr.update(value=None), gr.update(value=None)
299
 
300
- step_idx = int(step_idx)
301
- step_idx = max(0, min(step_idx, len(latents_list) - 1))
 
 
 
 
 
 
 
302
 
303
- pipe = get_pipe()
 
 
 
 
 
 
 
304
 
305
- # Decode image at this step
306
- try:
307
- step_image = decode_latent_to_pil(pipe, latents_list[step_idx])
308
- except Exception:
309
- step_image = None
310
 
311
- # Update PCA & norm plots
312
- pca_fig = make_pca_figure(pca_points, step_idx) if pca_points is not None else None
313
- norm_fig = make_norm_figure(norms, step_idx) if norms else None
314
 
315
- return gr.update(value=step_image), gr.update(value=pca_fig), gr.update(value=norm_fig)
 
 
 
 
316
 
 
 
 
 
 
 
317
 
318
- # -------------------- GRADIO UI -------------------- #
319
 
320
- with gr.Blocks(title="Diffusion Visualizer — Noise to Image") as demo:
321
 
322
- gr.Markdown("# 🧠 Image Diffusion Visualizer (Advanced)")
323
- gr.Markdown(
324
- "See how a tiny Stable Diffusion model turns **pure noise** into an image "
325
- "step by step. Use the slider to move through the diffusion process."
326
- )
327
 
328
- with gr.Row():
329
- with gr.Column(scale=2):
330
- prompt_box = gr.Textbox(
331
- label="Prompt",
332
- value="a small house in the forest, digital art",
333
- lines=3
334
- )
335
- num_steps_slider = gr.Slider(
336
- minimum=5, maximum=50, value=20, step=1,
337
- label="Number of diffusion steps"
338
- )
339
- guidance_slider = gr.Slider(
340
- minimum=1.0, maximum=10.0, value=7.5, step=0.5,
341
- label="Guidance scale (higher = follow prompt more)"
342
- )
343
- seed_box = gr.Number(
344
- label="Seed (leave -1 for random)",
345
- value=-1,
346
- precision=0
347
- )
348
- simple_mode_chk = gr.Checkbox(
349
- label="Explain in simple terms (for kids/elders)",
350
- value=True
351
- )
352
- run_btn = gr.Button("Generate & Analyze", variant="primary")
353
-
354
- with gr.Column(scale=2):
355
- final_image = gr.Image(label="Final generated image")
356
- explanation_md = gr.Markdown(label="Explanation")
357
-
358
- gr.Markdown("### 🔍 Explore the denoising process")
359
- step_slider = gr.Slider(
360
- minimum=0, maximum=0, value=0, step=1,
361
- label="View step (0 = early, noisy • max = late, clear)"
362
- )
363
 
364
  with gr.Row():
365
  with gr.Column():
366
- step_image = gr.Image(label="Image at this diffusion step")
367
- with gr.Column():
368
- pca_plot = gr.Plot(label="Latent PCA trajectory")
 
 
 
 
369
  with gr.Column():
370
- norm_plot = gr.Plot(label="Latent norm vs step")
 
371
 
 
 
 
 
372
  state = gr.State()
373
 
374
- # Wire run button
375
- run_btn.click(
376
- run_diffusion_analysis,
377
- inputs=[prompt_box, num_steps_slider, guidance_slider, seed_box, simple_mode_chk],
378
- outputs=[final_image, explanation_md, step_slider, step_image, pca_plot, norm_plot, state]
379
  )
380
 
381
- # Wire slider change
382
- step_slider.change(
383
- update_step_view,
384
- inputs=[state, step_slider],
385
- outputs=[step_image, pca_plot, norm_plot]
386
- )
387
 
388
  demo.launch()
 
1
+ # ==========================================================
2
+ # Stable Diffusion v1-4 CPU Optimized Diffusion Visualizer
3
+ # REAL images (256×256) on free HuggingFace CPU
4
+ # With: step-by-step latents, PCA path, norm plot, latents decode
5
+ # ==========================================================
6
 
7
  import gradio as gr
8
  import torch
9
  import numpy as np
10
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
11
  from sklearn.decomposition import PCA
12
  import plotly.graph_objects as go
 
13
  from PIL import Image
14
  import time
15
+ import warnings
16
 
17
+ warnings.filterwarnings("ignore")
18
+
19
+ # ------------------- CPU SETTINGS -------------------
20
+
21
+ DEVICE = "cpu"
22
+
23
+ # Disable MKLDNN for safety (prevents matmul errors on SD)
24
+ torch.backends.mkldnn.enabled = False
25
+
26
+ MODEL_ID = "CompVis/stable-diffusion-v1-4"
27
 
28
  PIPE_CACHE = None
29
 
30
 
31
+ # ------------------- LOAD SD MODEL -------------------
32
 
33
  def get_pipe():
 
34
  global PIPE_CACHE
35
+ if PIPE_CACHE:
36
  return PIPE_CACHE
37
+
38
+ pipe = StableDiffusionPipeline.from_pretrained(
39
+ MODEL_ID,
40
+ torch_dtype=torch.float32,
41
+ low_cpu_mem_usage=True,
42
+ )
43
+
44
+ # Replace scheduler with DDIM (better for stepping)
45
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
46
+
47
  pipe.to(DEVICE)
48
+
49
+ # VERY IMPORTANT: disable safety checker to avoid weird errors on CPU
50
+ pipe.safety_checker = lambda images, clip_input: (images, False)
51
+
52
+ # Disable features not needed
53
+ pipe.enable_attention_slicing(None)
54
+
55
  PIPE_CACHE = pipe
56
  return PIPE_CACHE
57
 
58
 
59
+ # ------------------- PCA + NORM -------------------
60
 
61
+ def compute_pca(latents):
62
+ flat = [x.flatten() for x in latents]
63
+ X = np.stack(flat)
64
+ if X.shape[0] < 2:
65
+ return np.zeros((X.shape[0], 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  try:
67
  pca = PCA(n_components=2)
68
+ pts = pca.fit_transform(X)
69
  return pts
70
+ except:
71
+ return np.zeros((X.shape[0], 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
+ def compute_norm(latents):
75
+ return [float(np.linalg.norm(x.flatten())) for x in latents]
76
+
77
+
78
+ # ------------------- LATENT DECODER -------------------
79
+
80
+ def decode_latent(pipe, latent_np):
81
+ latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
82
+ scale = pipe.vae.config.scaling_factor
83
+ with torch.no_grad():
84
+ image = pipe.vae.decode(latent / scale).sample
85
+ image = (image / 2 + 0.5).clamp(0, 1)
86
+ np_img = (image[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
87
+ return Image.fromarray(np_img)
88
+
89
+
90
+ # ------------------- RUN DIFFUSION -------------------
 
 
 
 
 
 
 
 
91
 
92
+ def run_diffusion(prompt, steps, guidance, seed, simple):
93
 
94
+ if not prompt.strip():
95
+ return None, "Enter prompt", gr.update(), None, None, None, {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  pipe = get_pipe()
 
 
98
 
99
+ generator = torch.Generator("cpu").manual_seed(seed if seed >= 0 else int(time.time()))
 
 
 
 
100
 
101
  latents_list = []
102
+ timesteps = []
103
 
104
+ def cb(step, t, latents):
 
105
  latents_list.append(latents.detach().cpu().numpy()[0])
106
+ timesteps.append(int(t))
107
 
108
  t0 = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ result = pipe(
111
+ prompt,
112
+ height=256,
113
+ width=256,
114
+ num_inference_steps=steps,
115
+ guidance_scale=guidance,
116
+ generator=generator,
117
+ callback=cb,
118
+ callback_steps=1,
119
+ )
120
+
121
+ total = time.time() - t0
122
 
123
+ final = result.images[0]
124
+
125
+ pca = compute_pca(latents_list)
126
+ norms = compute_norm(latents_list)
127
+
128
+ cur = len(latents_list) - 1
129
+ step_image = decode_latent(pipe, latents_list[cur])
130
+
131
+ explanation = (
132
+ "🧒 **Simple Explanation**\n"
133
+ "The model starts with noise, slowly removes it, and reveals an image.\n"
134
+ if simple else
135
+ "🔬 **Technical Explanation**\n"
136
+ "We collect latents at each DDIM step, decode them via VAE, and visualize their PCA path."
137
+ )
138
+ explanation += f"\n⏱ Runtime: {total:.2f}s"
139
 
 
140
  state = {
 
 
 
 
141
  "latents": latents_list,
142
+ "pca": pca,
 
143
  "norms": norms
144
  }
145
 
 
 
146
  return (
147
+ final,
148
  explanation,
149
+ gr.update(maximum=len(latents_list)-1, value=cur),
150
  step_image,
151
+ plot_pca(pca, cur),
152
+ plot_norm(norms, cur),
153
  state
154
  )
155
 
156
 
157
+ # ------------------- PLOT FUNCTIONS -------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ def plot_pca(points, idx):
160
+ fig = go.Figure()
161
+ fig.add_trace(go.Scatter(x=points[:,0], y=points[:,1], mode="lines+markers"))
162
+ fig.add_trace(go.Scatter(
163
+ x=[points[idx,0]], y=[points[idx,1]],
164
+ mode="markers", marker=dict(size=12, color="red")
165
+ ))
166
+ fig.update_layout(height=350, title="PCA Trajectory")
167
+ return fig
168
 
169
+ def plot_norm(norms, idx):
170
+ fig = go.Figure()
171
+ fig.add_trace(go.Scatter(y=norms, mode="lines+markers"))
172
+ fig.add_trace(go.Scatter(
173
+ x=[idx], y=[norms[idx]], mode="markers", marker=dict(size=12, color="red")
174
+ ))
175
+ fig.update_layout(height=350, title="Latent Norm Over Steps")
176
+ return fig
177
 
 
 
 
 
 
178
 
179
+ # ------------------- SLIDER UPDATE -------------------
 
 
180
 
181
+ def update_step(state, idx):
182
+ latents = state["latents"]
183
+ pca = state["pca"]
184
+ norms = state["norms"]
185
+ pipe = get_pipe()
186
 
187
+ img = decode_latent(pipe, latents[idx])
188
+ return (
189
+ img,
190
+ plot_pca(pca, idx),
191
+ plot_norm(norms, idx)
192
+ )
193
 
 
194
 
195
+ # ------------------- UI -------------------
196
 
197
+ with gr.Blocks(title="SD v1-4 CPU Diffusion Visualizer") as demo:
 
 
 
 
198
 
199
+ gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
200
+ gr.Markdown("This version produces **real images**, optimized for free HF CPU.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  with gr.Row():
203
  with gr.Column():
204
+ prompt = gr.Textbox(label="Prompt", value="a cute cat in watercolor")
205
+ steps = gr.Slider(10, 30, value=20, step=1, label="Steps")
206
+ guidance = gr.Slider(3, 12, value=7.5, step=0.5, label="Guidance")
207
+ seed = gr.Number(label="Seed (-1 for random)", value=-1)
208
+ simple = gr.Checkbox(label="Simple Explanation", value=True)
209
+ run = gr.Button("Run Diffusion", variant="primary")
210
+
211
  with gr.Column():
212
+ final = gr.Image(label="Final Image")
213
+ expl = gr.Markdown()
214
 
215
+ step_slider = gr.Slider(0, 0, value=0, step=1, label="View Step")
216
+ step_img = gr.Image(label="Latent Image at Step")
217
+ pca_plot = gr.Plot(label="PCA")
218
+ norm_plot = gr.Plot(label="Norm Plot")
219
  state = gr.State()
220
 
221
+ run.click(
222
+ run_diffusion,
223
+ inputs=[prompt, steps, guidance, seed, simple],
224
+ outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
 
225
  )
226
 
227
+ step_slider.change(update_step, [state, step_slider], [step_img, pca_plot, norm_plot])
 
 
 
 
 
228
 
229
  demo.launch()