king159 committed on
Commit e4924c3
1 Parent(s): 771cc14
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
app.py CHANGED
@@ -32,11 +32,13 @@ article = r"""
 <br>
 If you found this demo/our paper useful, please consider citing:
 ```bibtex
-@article{he024paid,
-  title={PAID:(Prompt-guided) Attention Interpolation of Text-to-Image Diffusion},
-  author={He, Qiyuan and Wang, Jinghao and Liu, Ziwei and Angle, Yao},
-  journal={},
-  year={2024}
+@misc{he2024aid,
+  title={AID: Attention Interpolation of Text-to-Image Diffusion},
+  author={Qiyuan He and Jinghao Wang and Ziwei Liu and Angela Yao},
+  year={2024},
+  eprint={2403.17924},
+  archivePrefix={arXiv},
+  primaryClass={cs.CV}
 }
 ```
 📧 **Contact**
@@ -50,18 +52,17 @@ USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 PREVIEW_IMAGES = False
 
-dtype = torch.float32
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 pipeline = InterpolationStableDiffusionPipeline(
     repo_name="runwayml/stable-diffusion-v1-5",
     guidance_scale=10.0,
     scheduler_name="unipc",
 )
-pipeline.to(device, dtype=dtype)
+pipeline.to(device, dtype=torch.float32)
 
 
 def change_model_fn(model_name: str) -> None:
-    global pipeline
+    global device
     name_mapping = {
         "SD1.4-521": "CompVis/stable-diffusion-v1-4",
         "SD1.5-512": "runwayml/stable-diffusion-v1-5",
@@ -69,17 +70,21 @@ def change_model_fn(model_name: str) -> None:
         "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
     }
     if "XL" not in model_name:
-        pipeline = InterpolationStableDiffusionPipeline(
+        globals()["pipeline"] = InterpolationStableDiffusionPipeline(
             repo_name=name_mapping[model_name],
             guidance_scale=10.0,
             scheduler_name="unipc",
         )
-        pipeline.to(device, dtype=dtype)
+        globals()["pipeline"].to(device, dtype=torch.float32)
     else:
-        pipeline = InterpolationStableDiffusionXLPipeline.from_pretrained(
-            name_mapping[model_name]
+        if device == torch.device("cpu"):
+            dtype = torch.float32
+        else:
+            dtype = torch.float16
+        globals()["pipeline"] = InterpolationStableDiffusionXLPipeline.from_pretrained(
+            name_mapping[model_name], torch_dtype=dtype
         )
-        pipeline.to(device, dtype=dtype)
+        globals()["pipeline"].to(device)
 
 
 def save_image(img, index):
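Inside `change_model_fn`, a bare `pipeline = ...` would only bind a function-local name and leave the module-level `pipeline` (the one the generate path reads) untouched; writing through `globals()` rebinds the module-level name, which is why the old `global pipeline` declaration could be dropped. A self-contained sketch of the pattern:

```python
pipeline = "SD1.5"  # stands in for the loaded diffusion pipeline

def change_model(name: str) -> None:
    # A plain `pipeline = name` here would create a local variable.
    # Writing through globals() rebinds the module-level name instead,
    # equivalent to declaring `global pipeline` before assigning.
    globals()["pipeline"] = name

change_model("SDXL")
print(pipeline)  # -> "SDXL"
```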
@@ -107,7 +112,7 @@ def plot_gemma_fn(alpha: float, beta: float, size: int) -> pd.DataFrame:
     )
 
 
-def get_example() -> list:
+def get_example() -> list[list[str | float | int]]:
     case = [
         [
             "A photo of dog, best quality, extremely detailed",
@@ -115,7 +120,7 @@ def get_example() -> list:
             3,
             6,
             3,
-            "A photo of a dog driving a car, logical, best quality, extremely detailed",
+            "A car with dog furry texture, best quality, extremely detailed",
             "monochrome, lowres, bad anatomy, worst quality, low quality",
             "SD1.5-512",
             6.1 / 50,
@@ -125,11 +130,52 @@ def get_example() -> list:
             "self",
             1002,
             True,
-        ]
+        ],
+        [
+            "A photo of dog, best quality, extremely detailed",
+            "A photo of car, best quality, extremely detailed",
+            7,
+            8,
+            8,
+            "A toy named dog-car, best quality, extremely detailed",
+            "monochrome, lowres, bad anatomy, worst quality, low quality",
+            "SD1.5-512",
+            8.1 / 50,
+            10,
+            50,
+            "fused_inner",
+            "self",
+            1002,
+            True,
+        ],
+        [
+            "anime artwork a Pokemon called Pikachu sitting on the grass, dramatic, anime style, key visual, vibrant, studio anime, highly detailed",
+            "anime artwork a beautiful girl, dramatic, anime style, key visual, vibrant, studio anime, highly detailed",
+            7,
+            3,
+            3,
+            None,
+            "monochrome, lowres, bad anatomy, worst quality, low quality",
+            "SDXL-1024",
+            25 / 50,
+            10,
+            50,
+            "fused_outer",
+            "self",
+            1002,
+            False,
+        ],
     ]
     return case
 
 
+def change_generate_button_fn(enable: int) -> gr.Button:
+    if enable == 0:
+        return gr.Button(interactive=False, value="Switching Model...")
+    else:
+        return gr.Button(interactive=True, value="Generate")
+
+
 def dynamic_gallery_fn(interpolation_size: int):
 
     return gr.Gallery(
@@ -192,6 +238,9 @@ def generate(
             negative_prompt=negative_prompt,
             guidance_scale=guidance_scale,
         )
+        if hasattr(images, "images"):
+            # for sdxl
+            images = np.array(images.images)
         if interpolation_size == 3:
             final_images = images
             break
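The `hasattr(images, "images")` guard bridges two return types: the custom SD pipeline yields a numpy array directly, while the diffusers-style SDXL pipeline returns an output object whose `.images` field is a list of PIL images. The same normalization as a helper, sketched here under that assumption (the helper name is ours):

```python
import numpy as np

def to_image_array(result) -> np.ndarray:
    """Normalize pipeline output to an array of images."""
    if hasattr(result, "images"):       # diffusers-style output object (SDXL path)
        return np.array(result.images)  # list of PIL images -> array
    return np.asarray(result)           # already array-like (SD1.x path)
```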
@@ -206,7 +255,7 @@
 
 interpolation_size = None
 
-with gr.Blocks() as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(title)
     gr.Markdown(description)
     with gr.Group():
@@ -225,7 +274,7 @@ with gr.Blocks() as demo:
                 value="A photo of car, best quality, extremely detaile",
             )
         result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
-        generate_button = gr.Button("Generate", variant="primary")
+        generate_button = gr.Button(value="Generate", variant="primary")
     with gr.Accordion("Advanced options", open=True):
         with gr.Group():
             with gr.Row():
@@ -242,14 +291,14 @@ with gr.Blocks() as demo:
                     label="alpha",
                     minimum=1,
                     maximum=50,
-                    step=0.1,
+                    step=1,
                     value=6.0,
                 )
                 beta = gr.Slider(
                     label="beta",
                     minimum=1,
                     maximum=50,
-                    step=0.1,
+                    step=1,
                     value=3.0,
                 )
                 gamma_plot = gr.LinePlot(
@@ -346,6 +395,7 @@ with gr.Blocks() as demo:
                     label="Model",
                     value="SD1.5-512",
                     interactive=True,
+                    info="SDXL will run on float16 while the rest will run on float32.",
                 )
             with gr.Column():
                 seed = gr.Slider(
@@ -381,8 +431,6 @@ with gr.Blocks() as demo:
             seed,
             same_latent,
         ],
-        outputs=result,
-        fn=generate,
         cache_examples=CACHE_EXAMPLES,
     )
 
@@ -395,7 +443,15 @@ with gr.Blocks() as demo:
     interpolation_size.change(
         fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
     )
-    model_choice.change(fn=change_model_fn, inputs=model_choice)
+    model_choice.change(
+        fn=change_generate_button_fn,
+        inputs=gr.Number(0, visible=False),
+        outputs=generate_button,
+    ).then(fn=change_model_fn, inputs=model_choice).then(
+        fn=change_generate_button_fn,
+        inputs=gr.Number(1, visible=False),
+        outputs=generate_button,
+    )
     inputs = [
         prompt1,
         prompt2,
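The chained `model_choice.change(...).then(...).then(...)` implements a lock/swap/unlock cycle so users cannot click Generate mid-switch. A self-contained sketch of the pattern, assuming a Gradio version where event listeners return a chainable dependency and a component returned from a callback replaces the one named in `outputs` (the repo relies on the same behavior):

```python
import time
import gradio as gr

def lock() -> gr.Button:
    return gr.Button(interactive=False, value="Switching Model...")

def switch_model(name: str) -> None:
    time.sleep(2)  # stands in for reloading a diffusion pipeline

def unlock() -> gr.Button:
    return gr.Button(interactive=True, value="Generate")

with gr.Blocks() as demo:
    choice = gr.Dropdown(["SD1.5-512", "SDXL-1024"], value="SD1.5-512")
    button = gr.Button("Generate", variant="primary")
    # Each .then() step starts only after the previous one finishes.
    choice.change(fn=lock, outputs=button).then(
        fn=switch_model, inputs=choice
    ).then(fn=unlock, outputs=button)

demo.launch()
```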
@@ -423,9 +479,4 @@ with gr.Blocks() as demo:
     )
     gr.Markdown(article)
 
-with gr.Blocks(css="style.css") as demo_with_history:
-    with gr.Tab("App"):
-        demo.render()
-
-if __name__ == "__main__":
-    demo_with_history.queue(max_size=20).launch()
+demo.launch()
pipeline_interpolated_sdxl.py CHANGED
@@ -403,6 +403,12 @@ class InterpolationStableDiffusionXLPipeline(
         else:
             self.watermark = None
 
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.vae.to(*args, **kwargs)
+        self.text_encoder.to(*args, **kwargs)
+        self.unet.to(*args, **kwargs)
+
     def generate_latent(
         self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
     ) -> torch.FloatTensor:
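The `to()` override forwards device/dtype moves to each submodule explicitly, guaranteeing `vae`, `text_encoder`, and `unet` always travel together. A toy sketch of the forwarding idea (toy modules, not the real pipeline):

```python
import torch
import torch.nn as nn

class ToyPipeline:
    """Toy stand-in that forwards .to() to every submodule it owns."""

    def __init__(self) -> None:
        self.vae = nn.Linear(4, 4)
        self.text_encoder = nn.Linear(4, 4)
        self.unet = nn.Linear(4, 4)

    def to(self, *args, **kwargs) -> "ToyPipeline":
        for module in (self.vae, self.text_encoder, self.unet):
            module.to(*args, **kwargs)
        return self

pipe = ToyPipeline().to(torch.float16)
assert pipe.unet.weight.dtype is torch.float16
```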
pipeline_interpolated_stable_diffusion.py CHANGED
@@ -286,7 +286,7 @@ class InterpolationStableDiffusionPipeline:
             noise_pred = self.unet(
                 latent_model_input, t, encoder_hidden_states=embs
             ).sample
-            attn_proc = AttnProcessor()
+            attn_proc = AttnProcessor2_0()
             self.unet.set_attn_processor(processor=attn_proc)
             noise_uncond = self.unet(
                 latent_model_input, t, encoder_hidden_states=uncond_embs
@@ -477,7 +477,7 @@ class InterpolationStableDiffusionPipeline:
                 t=it,
                 is_fused=True,
             )
-            self_attn_proc = AttnProcessor()
+            self_attn_proc = AttnProcessor2_0()
             procs_dict = {
                 "pure_inner": pure_inner_attn_proc,
                 "fused_inner": fused_inner_attn_proc,
@@ -503,7 +503,7 @@ class InterpolationStableDiffusionPipeline:
             noise_pred = self.unet(
                 latent_model_input, t, encoder_hidden_states=embs
             ).sample
-            attn_proc = AttnProcessor()
+            attn_proc = AttnProcessor2_0()
             self.unet.set_attn_processor(processor=attn_proc)
             noise_uncond = self.unet(
                 latent_model_input, t, encoder_hidden_states=uncond_embs
@@ -544,7 +544,7 @@ class InterpolationStableDiffusionPipeline:
         Returns:
             numpy.ndarray: The interpolated images.
         """
-        self.unet.set_attn_processor(processor=AttnProcessor())
+        self.unet.set_attn_processor(processor=AttnProcessor2_0())
         start_emb = self.prompt_to_embedding(text_1)
         end_emb = self.prompt_to_embedding(text_2)
         neg_emb = self.prompt_to_embedding(negative_prompt)
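`AttnProcessor2_0` is the diffusers processor backed by PyTorch 2.0's `torch.nn.functional.scaled_dot_product_attention`, so resetting to it instead of the vanilla `AttnProcessor` keeps the fused attention kernels wherever the interpolation processors are cleared. A hedged sketch of a version-safe way to pick the reset processor, assuming a diffusers release that exposes both classes (the helper is ours, not the repo's):

```python
import torch
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

def default_attn_processor():
    """Prefer the fused SDPA processor when torch >= 2.0 provides it."""
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return AttnProcessor2_0()
    return AttnProcessor()

# usage: unet.set_attn_processor(processor=default_attn_processor())
```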