PraneshJs commited on
Commit
3cb2178
·
verified ·
1 Parent(s): 4a4fecc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -83
app.py CHANGED
@@ -1,7 +1,9 @@
1
  # ==========================================================
2
- # Stable Diffusion v1-4 — CPU Optimized Diffusion Visualizer
3
- # REAL images (256×256) on free HuggingFace CPU
4
- # With: step-by-step latents, PCA path, norm plot, latents decode
 
 
5
  # ==========================================================
6
 
7
  import gradio as gr
@@ -20,7 +22,7 @@ warnings.filterwarnings("ignore")
20
 
21
  DEVICE = "cpu"
22
 
23
- # Disable MKLDNN for safety (prevents matmul errors on SD)
24
  torch.backends.mkldnn.enabled = False
25
 
26
  MODEL_ID = "CompVis/stable-diffusion-v1-4"
@@ -31,27 +33,26 @@ PIPE_CACHE = None
31
  # ------------------- LOAD SD MODEL -------------------
32
 
33
  def get_pipe():
 
 
 
 
34
  global PIPE_CACHE
35
- if PIPE_CACHE:
36
  return PIPE_CACHE
37
 
38
  pipe = StableDiffusionPipeline.from_pretrained(
39
  MODEL_ID,
40
  torch_dtype=torch.float32,
41
- low_cpu_mem_usage=True,
 
42
  )
43
 
44
- # Replace scheduler with DDIM (better for stepping)
45
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
46
 
47
  pipe.to(DEVICE)
48
 
49
- # VERY IMPORTANT: disable safety checker to avoid weird errors on CPU
50
- pipe.safety_checker = lambda images, clip_input: (images, False)
51
-
52
- # Disable features not needed
53
- pipe.enable_attention_slicing(None)
54
-
55
  PIPE_CACHE = pipe
56
  return PIPE_CACHE
57
 
@@ -59,6 +60,12 @@ def get_pipe():
59
  # ------------------- PCA + NORM -------------------
60
 
61
  def compute_pca(latents):
 
 
 
 
 
 
62
  flat = [x.flatten() for x in latents]
63
  X = np.stack(flat)
64
  if X.shape[0] < 2:
@@ -67,17 +74,25 @@ def compute_pca(latents):
67
  pca = PCA(n_components=2)
68
  pts = pca.fit_transform(X)
69
  return pts
70
- except:
71
  return np.zeros((X.shape[0], 2))
72
 
73
 
74
  def compute_norm(latents):
 
 
 
 
 
75
  return [float(np.linalg.norm(x.flatten())) for x in latents]
76
 
77
 
78
  # ------------------- LATENT DECODER -------------------
79
 
80
  def decode_latent(pipe, latent_np):
 
 
 
81
  latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
82
  scale = pipe.vae.config.scaling_factor
83
  with torch.no_grad():
@@ -87,135 +102,286 @@ def decode_latent(pipe, latent_np):
87
  return Image.fromarray(np_img)
88
 
89
 
90
- # ------------------- RUN DIFFUSION -------------------
91
 
92
  def run_diffusion(prompt, steps, guidance, seed, simple):
93
-
94
- if not prompt.strip():
95
- return None, "Enter prompt", gr.update(), None, None, None, {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  pipe = get_pipe()
98
 
99
- generator = torch.Generator("cpu").manual_seed(seed if seed >= 0 else int(time.time()))
 
 
 
 
 
 
 
 
100
 
101
  latents_list = []
102
  timesteps = []
103
 
104
- def cb(step, t, latents):
 
105
  latents_list.append(latents.detach().cpu().numpy()[0])
106
- timesteps.append(int(t))
107
 
108
  t0 = time.time()
109
-
110
- result = pipe(
111
- prompt,
112
- height=256,
113
- width=256,
114
- num_inference_steps=steps,
115
- guidance_scale=guidance,
116
- generator=generator,
117
- callback=cb,
118
- callback_steps=1,
119
- )
 
 
 
 
 
 
 
 
 
 
120
 
121
  total = time.time() - t0
122
 
123
- final = result.images[0]
124
-
125
- pca = compute_pca(latents_list)
 
 
 
 
 
 
 
 
 
 
 
 
126
  norms = compute_norm(latents_list)
127
 
128
- cur = len(latents_list) - 1
129
- step_image = decode_latent(pipe, latents_list[cur])
130
-
131
- explanation = (
132
- "🧒 **Simple Explanation**\n"
133
- "The model starts with noise, slowly removes it, and reveals an image.\n"
134
- if simple else
135
- "🔬 **Technical Explanation**\n"
136
- "We collect latents at each DDIM step, decode them via VAE, and visualize their PCA path."
137
- )
138
- explanation += f"\n⏱ Runtime: {total:.2f}s"
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  state = {
141
  "latents": latents_list,
142
- "pca": pca,
143
  "norms": norms
144
  }
145
 
 
 
146
  return (
147
- final,
148
  explanation,
149
- gr.update(maximum=len(latents_list)-1, value=cur),
150
  step_image,
151
- plot_pca(pca, cur),
152
- plot_norm(norms, cur),
153
  state
154
  )
155
 
156
 
157
- # ------------------- PLOT FUNCTIONS -------------------
158
 
159
  def plot_pca(points, idx):
 
 
 
 
 
 
 
 
160
  fig = go.Figure()
161
- fig.add_trace(go.Scatter(x=points[:,0], y=points[:,1], mode="lines+markers"))
162
  fig.add_trace(go.Scatter(
163
- x=[points[idx,0]], y=[points[idx,1]],
164
- mode="markers", marker=dict(size=12, color="red")
 
 
 
165
  ))
166
- fig.update_layout(height=350, title="PCA Trajectory")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  return fig
168
 
 
169
  def plot_norm(norms, idx):
 
 
 
 
 
 
170
  fig = go.Figure()
171
- fig.add_trace(go.Scatter(y=norms, mode="lines+markers"))
172
  fig.add_trace(go.Scatter(
173
- x=[idx], y=[norms[idx]], mode="markers", marker=dict(size=12, color="red")
 
 
 
174
  ))
175
- fig.update_layout(height=350, title="Latent Norm Over Steps")
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  return fig
177
 
178
 
179
  # ------------------- SLIDER UPDATE -------------------
180
 
181
  def update_step(state, idx):
 
 
 
 
 
 
 
 
 
182
  latents = state["latents"]
183
- pca = state["pca"]
184
  norms = state["norms"]
 
 
 
 
 
 
 
185
  pipe = get_pipe()
186
 
187
- img = decode_latent(pipe, latents[idx])
188
- return (
189
- img,
190
- plot_pca(pca, idx),
191
- plot_norm(norms, idx)
192
- )
 
193
 
 
194
 
195
- # ------------------- UI -------------------
196
 
197
- with gr.Blocks(title="SD v1-4 CPU Diffusion Visualizer") as demo:
 
 
198
 
199
  gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
200
- gr.Markdown("This version produces **real images**, optimized for free HF CPU.")
 
 
 
 
 
201
 
202
  with gr.Row():
203
  with gr.Column():
204
- prompt = gr.Textbox(label="Prompt", value="a cute cat in watercolor")
205
- steps = gr.Slider(10, 30, value=20, step=1, label="Steps")
206
- guidance = gr.Slider(3, 12, value=7.5, step=0.5, label="Guidance")
207
- seed = gr.Number(label="Seed (-1 for random)", value=-1)
208
- simple = gr.Checkbox(label="Simple Explanation", value=True)
209
- run = gr.Button("Run Diffusion", variant="primary")
 
 
 
 
210
 
211
  with gr.Column():
212
- final = gr.Image(label="Final Image")
213
- expl = gr.Markdown()
 
 
 
 
 
 
 
214
 
215
- step_slider = gr.Slider(0, 0, value=0, step=1, label="View Step")
216
- step_img = gr.Image(label="Latent Image at Step")
217
- pca_plot = gr.Plot(label="PCA")
218
- norm_plot = gr.Plot(label="Norm Plot")
219
  state = gr.State()
220
 
221
  run.click(
@@ -224,6 +390,10 @@ with gr.Blocks(title="SD v1-4 CPU Diffusion Visualizer") as demo:
224
  outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
225
  )
226
 
227
- step_slider.change(update_step, [state, step_slider], [step_img, pca_plot, norm_plot])
 
 
 
 
228
 
229
- demo.launch()
 
1
  # ==========================================================
2
+ # Stable Diffusion v1-4 — CPU Diffusion Visualizer (256x256)
3
+ # - Runs on HF CPU
4
+ # - Real images (not blurry)
5
+ # - Step-by-step latents
6
+ # - PCA trajectory + latent norm plots
7
  # ==========================================================
8
 
9
  import gradio as gr
 
22
 
23
  DEVICE = "cpu"
24
 
25
+ # Sometimes MKLDNN causes weird matmul errors with SD on some CPUs, disable to be safe.
26
  torch.backends.mkldnn.enabled = False
27
 
28
  MODEL_ID = "CompVis/stable-diffusion-v1-4"
 
33
  # ------------------- LOAD SD MODEL -------------------
34
 
35
  def get_pipe():
36
+ """
37
+ Load and cache the Stable Diffusion v1-4 pipeline on CPU,
38
+ with safety checker DISABLED correctly.
39
+ """
40
  global PIPE_CACHE
41
+ if PIPE_CACHE is not None:
42
  return PIPE_CACHE
43
 
44
  pipe = StableDiffusionPipeline.from_pretrained(
45
  MODEL_ID,
46
  torch_dtype=torch.float32,
47
+ safety_checker=None, # <--- disable safety checker properly
48
+ requires_safety_checker=False
49
  )
50
 
51
+ # Use DDIM so we have clear, predictable timesteps for visualization
52
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
53
 
54
  pipe.to(DEVICE)
55
 
 
 
 
 
 
 
56
  PIPE_CACHE = pipe
57
  return PIPE_CACHE
58
 
 
60
  # ------------------- PCA + NORM -------------------
61
 
62
  def compute_pca(latents):
63
+ """
64
+ latents: list of (C,H,W) numpy arrays.
65
+ Returns Nx2 array of PCA coords (one point per step).
66
+ """
67
+ if not latents:
68
+ return np.zeros((0, 2))
69
  flat = [x.flatten() for x in latents]
70
  X = np.stack(flat)
71
  if X.shape[0] < 2:
 
74
  pca = PCA(n_components=2)
75
  pts = pca.fit_transform(X)
76
  return pts
77
+ except Exception:
78
  return np.zeros((X.shape[0], 2))
79
 
80
 
81
  def compute_norm(latents):
82
+ """
83
+ L2 norm of each latent over all dims.
84
+ """
85
+ if not latents:
86
+ return []
87
  return [float(np.linalg.norm(x.flatten())) for x in latents]
88
 
89
 
90
  # ------------------- LATENT DECODER -------------------
91
 
92
  def decode_latent(pipe, latent_np):
93
+ """
94
+ Decode a single latent (C,H,W) numpy array into a 256x256 RGB PIL image.
95
+ """
96
  latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
97
  scale = pipe.vae.config.scaling_factor
98
  with torch.no_grad():
 
102
  return Image.fromarray(np_img)
103
 
104
 
105
+ # ------------------- MAIN DIFFUSION RUN -------------------
106
 
107
  def run_diffusion(prompt, steps, guidance, seed, simple):
108
+ """
109
+ Run SD v1-4 at 256x256, capturing latents at EVERY step via callback.
110
+ Returns:
111
+ - final image
112
+ - explanation text
113
+ - step slider config
114
+ - image at current step
115
+ - PCA plot
116
+ - norm plot
117
+ - state dict (for slider updates)
118
+ """
119
+
120
+ if not prompt or not prompt.strip():
121
+ return (
122
+ None,
123
+ "⚠️ Please enter a prompt.",
124
+ gr.update(maximum=0, value=0),
125
+ None,
126
+ None,
127
+ None,
128
+ {}
129
+ )
130
 
131
  pipe = get_pipe()
132
 
133
+ steps = int(steps)
134
+ guidance = float(guidance)
135
+
136
+ if seed is None or seed < 0:
137
+ seed_val = int(time.time())
138
+ else:
139
+ seed_val = int(seed)
140
+
141
+ generator = torch.Generator(device=DEVICE).manual_seed(seed_val)
142
 
143
  latents_list = []
144
  timesteps = []
145
 
146
+ def callback(step: int, timestep: int, latents: torch.FloatTensor):
147
+ # latents shape: (batch, C, H, W)
148
  latents_list.append(latents.detach().cpu().numpy()[0])
149
+ timesteps.append(int(timestep))
150
 
151
  t0 = time.time()
152
+ try:
153
+ result = pipe(
154
+ prompt,
155
+ height=256,
156
+ width=256,
157
+ num_inference_steps=steps,
158
+ guidance_scale=guidance,
159
+ generator=generator,
160
+ callback=callback,
161
+ callback_steps=1,
162
+ )
163
+ except Exception as e:
164
+ return (
165
+ None,
166
+ f"❌ Diffusion error: {e}",
167
+ gr.update(maximum=0, value=0),
168
+ None,
169
+ None,
170
+ None,
171
+ {"error": str(e)}
172
+ )
173
 
174
  total = time.time() - t0
175
 
176
+ if not latents_list:
177
+ return (
178
+ None,
179
+ "❌ No latents collected. Something went wrong inside the pipeline.",
180
+ gr.update(maximum=0, value=0),
181
+ None,
182
+ None,
183
+ None,
184
+ {"error": "no_latents"}
185
+ )
186
+
187
+ final_image = result.images[0] # PIL
188
+
189
+ # Compute PCA trajectory and norms
190
+ pca_pts = compute_pca(latents_list)
191
  norms = compute_norm(latents_list)
192
 
193
+ current_idx = len(latents_list) - 1 # final step
 
 
 
 
 
 
 
 
 
 
194
 
195
+ # Decode image at current step
196
+ try:
197
+ step_image = decode_latent(pipe, latents_list[current_idx])
198
+ except Exception:
199
+ step_image = None
200
+
201
+ # Explanation text
202
+ if simple:
203
+ explanation = (
204
+ "🧒 **Simple explanation of what you see:**\n\n"
205
+ "1. The model starts from pure noise.\n"
206
+ "2. At each step, it removes some noise and makes the picture clearer.\n"
207
+ "3. Your text prompt tells it what kind of picture to create.\n"
208
+ "4. You can move the slider to see the image at different steps.\n"
209
+ )
210
+ else:
211
+ explanation = (
212
+ "🔬 **Technical explanation:**\n\n"
213
+ "- We run a DDIM diffusion process over the latent space.\n"
214
+ "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
215
+ "- We record `zₜ` at every step and decode it with the VAE.\n"
216
+ "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
217
+ "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
218
+ )
219
+ explanation += f"\n⏱ **Runtime:** {total:.2f}s • **Steps:** {len(latents_list)} • Seed: {seed_val}"
220
+
221
+ # Build plots
222
+ pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
223
+ norm_fig = plot_norm(norms, current_idx) if norms else None
224
+
225
+ # State for slider updates
226
  state = {
227
  "latents": latents_list,
228
+ "pca": pca_pts,
229
  "norms": norms
230
  }
231
 
232
+ step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)
233
+
234
  return (
235
+ final_image,
236
  explanation,
237
+ step_slider_update,
238
  step_image,
239
+ pca_fig,
240
+ norm_fig,
241
  state
242
  )
243
 
244
 
245
+ # ------------------- PLOT HELPERS -------------------
246
 
247
  def plot_pca(points, idx):
248
+ """
249
+ PCA trajectory plot over steps, highlighting current step.
250
+ points: (N,2)
251
+ """
252
+ if points.shape[0] == 0:
253
+ return None
254
+
255
+ steps = list(range(points.shape[0]))
256
  fig = go.Figure()
 
257
  fig.add_trace(go.Scatter(
258
+ x=points[:, 0],
259
+ y=points[:, 1],
260
+ mode="lines+markers",
261
+ name="steps",
262
+ text=[f"step {i}" for i in steps]
263
  ))
264
+ if 0 <= idx < len(steps):
265
+ fig.add_trace(go.Scatter(
266
+ x=[points[idx, 0]],
267
+ y=[points[idx, 1]],
268
+ mode="markers+text",
269
+ text=[f"step {idx}"],
270
+ textposition="top center",
271
+ marker=dict(size=12, color="red"),
272
+ name="current"
273
+ ))
274
+ fig.update_layout(
275
+ title="Latent PCA trajectory",
276
+ xaxis_title="PC1",
277
+ yaxis_title="PC2",
278
+ height=350
279
+ )
280
  return fig
281
 
282
+
283
  def plot_norm(norms, idx):
284
+ """
285
+ Plot latent L2 norm vs step, highlight current step.
286
+ """
287
+ if not norms:
288
+ return None
289
+ steps = list(range(len(norms)))
290
  fig = go.Figure()
 
291
  fig.add_trace(go.Scatter(
292
+ x=steps,
293
+ y=norms,
294
+ mode="lines+markers",
295
+ name="‖latent‖₂"
296
  ))
297
+ if 0 <= idx < len(steps):
298
+ fig.add_trace(go.Scatter(
299
+ x=[idx],
300
+ y=[norms[idx]],
301
+ mode="markers",
302
+ marker=dict(size=12, color="red"),
303
+ name="current"
304
+ ))
305
+ fig.update_layout(
306
+ title="Latent L2 norm vs step",
307
+ xaxis_title="Step index",
308
+ yaxis_title="‖latent‖₂",
309
+ height=350
310
+ )
311
  return fig
312
 
313
 
314
  # ------------------- SLIDER UPDATE -------------------
315
 
316
  def update_step(state, idx):
317
+ """
318
+ When user moves the slider:
319
+ - decode latent at that step
320
+ - update PCA highlight
321
+ - update norm highlight
322
+ """
323
+ if not state or "latents" not in state:
324
+ return gr.update(value=None), gr.update(value=None), gr.update(value=None)
325
+
326
  latents = state["latents"]
327
+ pca_pts = state["pca"]
328
  norms = state["norms"]
329
+
330
+ if not latents:
331
+ return gr.update(value=None), gr.update(value=None), gr.update(value=None)
332
+
333
+ idx = int(idx)
334
+ idx = max(0, min(idx, len(latents) - 1))
335
+
336
  pipe = get_pipe()
337
 
338
+ try:
339
+ img = decode_latent(pipe, latents[idx])
340
+ except Exception:
341
+ img = None
342
+
343
+ pca_fig = plot_pca(pca_pts, idx) if pca_pts is not None else None
344
+ norm_fig = plot_norm(norms, idx) if norms else None
345
 
346
+ return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)
347
 
 
348
 
349
+ # ------------------- GRADIO UI -------------------
350
+
351
+ with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:
352
 
353
  gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
354
+ gr.Markdown(
355
+ "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
356
+ "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
357
+ "- 256×256 resolution for speed\n"
358
+ "- You can scrub through all diffusion steps\n"
359
+ )
360
 
361
  with gr.Row():
362
  with gr.Column():
363
+ prompt = gr.Textbox(
364
+ label="Prompt",
365
+ value="a small cozy cabin in the forest, digital art",
366
+ lines=3
367
+ )
368
+ steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
369
+ guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
370
+ seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
371
+ simple = gr.Checkbox(label="Simple explanation", value=True)
372
+ run = gr.Button("Run diffusion", variant="primary")
373
 
374
  with gr.Column():
375
+ final = gr.Image(label="Final generated image")
376
+ expl = gr.Markdown(label="Explanation")
377
+
378
+ gr.Markdown("### 🔍 Explore the denoising process step-by-step")
379
+
380
+ step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
381
+ step_img = gr.Image(label="Image at this diffusion step")
382
+ pca_plot = gr.Plot(label="Latent PCA trajectory")
383
+ norm_plot = gr.Plot(label="Latent norm vs step")
384
 
 
 
 
 
385
  state = gr.State()
386
 
387
  run.click(
 
390
  outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
391
  )
392
 
393
+ step_slider.change(
394
+ update_step,
395
+ inputs=[state, step_slider],
396
+ outputs=[step_img, pca_plot, norm_plot]
397
+ )
398
 
399
+ demo.launch()