king159
committed on
Commit
•
e4924c3
1
Parent(s):
771cc14
fix sdxl
Browse files
- .gitattributes +1 -1
- .gitignore +1 -0
- app.py +80 -29
- pipeline_interpolated_sdxl.py +6 -0
- pipeline_interpolated_stable_diffusion.py +4 -4
.gitattributes
CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
app.py
CHANGED
@@ -32,11 +32,13 @@ article = r"""
|
|
32 |
<br>
|
33 |
If you found this demo/our paper useful, please consider citing:
|
34 |
```bibtex
|
35 |
-
@
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
40 |
}
|
41 |
```
|
42 |
📧 **Contact**
|
@@ -50,18 +52,17 @@ USE_TORCH_COMPILE = False
|
|
50 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
51 |
PREVIEW_IMAGES = False
|
52 |
|
53 |
-
dtype = torch.float32
|
54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
pipeline = InterpolationStableDiffusionPipeline(
|
56 |
repo_name="runwayml/stable-diffusion-v1-5",
|
57 |
guidance_scale=10.0,
|
58 |
scheduler_name="unipc",
|
59 |
)
|
60 |
-
pipeline.to(device, dtype=
|
61 |
|
62 |
|
63 |
def change_model_fn(model_name: str) -> None:
|
64 |
-
global
|
65 |
name_mapping = {
|
66 |
"SD1.4-521": "CompVis/stable-diffusion-v1-4",
|
67 |
"SD1.5-512": "runwayml/stable-diffusion-v1-5",
|
@@ -69,17 +70,21 @@ def change_model_fn(model_name: str) -> None:
|
|
69 |
"SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
|
70 |
}
|
71 |
if "XL" not in model_name:
|
72 |
-
pipeline = InterpolationStableDiffusionPipeline(
|
73 |
repo_name=name_mapping[model_name],
|
74 |
guidance_scale=10.0,
|
75 |
scheduler_name="unipc",
|
76 |
)
|
77 |
-
pipeline.to(device, dtype=
|
78 |
else:
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
)
|
82 |
-
pipeline.to(device
|
83 |
|
84 |
|
85 |
def save_image(img, index):
|
@@ -107,7 +112,7 @@ def plot_gemma_fn(alpha: float, beta: float, size: int) -> pd.DataFrame:
|
|
107 |
)
|
108 |
|
109 |
|
110 |
-
def get_example() -> list:
|
111 |
case = [
|
112 |
[
|
113 |
"A photo of dog, best quality, extremely detailed",
|
@@ -115,7 +120,7 @@ def get_example() -> list:
|
|
115 |
3,
|
116 |
6,
|
117 |
3,
|
118 |
-
"A
|
119 |
"monochrome, lowres, bad anatomy, worst quality, low quality",
|
120 |
"SD1.5-512",
|
121 |
6.1 / 50,
|
@@ -125,11 +130,52 @@ def get_example() -> list:
|
|
125 |
"self",
|
126 |
1002,
|
127 |
True,
|
128 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
]
|
130 |
return case
|
131 |
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
def dynamic_gallery_fn(interpolation_size: int):
|
134 |
|
135 |
return gr.Gallery(
|
@@ -192,6 +238,9 @@ def generate(
|
|
192 |
negative_prompt=negative_prompt,
|
193 |
guidance_scale=guidance_scale,
|
194 |
)
|
|
|
|
|
|
|
195 |
if interpolation_size == 3:
|
196 |
final_images = images
|
197 |
break
|
@@ -206,7 +255,7 @@ def generate(
|
|
206 |
|
207 |
interpolation_size = None
|
208 |
|
209 |
-
with gr.Blocks() as demo:
|
210 |
gr.Markdown(title)
|
211 |
gr.Markdown(description)
|
212 |
with gr.Group():
|
@@ -225,7 +274,7 @@ with gr.Blocks() as demo:
|
|
225 |
value="A photo of car, best quality, extremely detaile",
|
226 |
)
|
227 |
result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
|
228 |
-
generate_button = gr.Button("Generate", variant="primary")
|
229 |
with gr.Accordion("Advanced options", open=True):
|
230 |
with gr.Group():
|
231 |
with gr.Row():
|
@@ -242,14 +291,14 @@ with gr.Blocks() as demo:
|
|
242 |
label="alpha",
|
243 |
minimum=1,
|
244 |
maximum=50,
|
245 |
-
step=
|
246 |
value=6.0,
|
247 |
)
|
248 |
beta = gr.Slider(
|
249 |
label="beta",
|
250 |
minimum=1,
|
251 |
maximum=50,
|
252 |
-
step=
|
253 |
value=3.0,
|
254 |
)
|
255 |
gamma_plot = gr.LinePlot(
|
@@ -346,6 +395,7 @@ with gr.Blocks() as demo:
|
|
346 |
label="Model",
|
347 |
value="SD1.5-512",
|
348 |
interactive=True,
|
|
|
349 |
)
|
350 |
with gr.Column():
|
351 |
seed = gr.Slider(
|
@@ -381,8 +431,6 @@ with gr.Blocks() as demo:
|
|
381 |
seed,
|
382 |
same_latent,
|
383 |
],
|
384 |
-
outputs=result,
|
385 |
-
fn=generate,
|
386 |
cache_examples=CACHE_EXAMPLES,
|
387 |
)
|
388 |
|
@@ -395,7 +443,15 @@ with gr.Blocks() as demo:
|
|
395 |
interpolation_size.change(
|
396 |
fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
|
397 |
)
|
398 |
-
model_choice.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
inputs = [
|
400 |
prompt1,
|
401 |
prompt2,
|
@@ -423,9 +479,4 @@ with gr.Blocks() as demo:
|
|
423 |
)
|
424 |
gr.Markdown(article)
|
425 |
|
426 |
-
|
427 |
-
with gr.Tab("App"):
|
428 |
-
demo.render()
|
429 |
-
|
430 |
-
if __name__ == "__main__":
|
431 |
-
demo_with_history.queue(max_size=20).launch()
|
|
|
32 |
<br>
|
33 |
If you found this demo/our paper useful, please consider citing:
|
34 |
```bibtex
|
35 |
+
@misc{he2024aid,
|
36 |
+
title={AID: Attention Interpolation of Text-to-Image Diffusion},
|
37 |
+
author={Qiyuan He and Jinghao Wang and Ziwei Liu and Angela Yao},
|
38 |
+
year={2024},
|
39 |
+
eprint={2403.17924},
|
40 |
+
archivePrefix={arXiv},
|
41 |
+
primaryClass={cs.CV}
|
42 |
}
|
43 |
```
|
44 |
📧 **Contact**
|
|
|
52 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
53 |
PREVIEW_IMAGES = False
|
54 |
|
|
|
55 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
pipeline = InterpolationStableDiffusionPipeline(
|
57 |
repo_name="runwayml/stable-diffusion-v1-5",
|
58 |
guidance_scale=10.0,
|
59 |
scheduler_name="unipc",
|
60 |
)
|
61 |
+
pipeline.to(device, dtype=torch.float32)
|
62 |
|
63 |
|
64 |
def change_model_fn(model_name: str) -> None:
|
65 |
+
global device
|
66 |
name_mapping = {
|
67 |
"SD1.4-521": "CompVis/stable-diffusion-v1-4",
|
68 |
"SD1.5-512": "runwayml/stable-diffusion-v1-5",
|
|
|
70 |
"SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
|
71 |
}
|
72 |
if "XL" not in model_name:
|
73 |
+
globals()["pipeline"] = InterpolationStableDiffusionPipeline(
|
74 |
repo_name=name_mapping[model_name],
|
75 |
guidance_scale=10.0,
|
76 |
scheduler_name="unipc",
|
77 |
)
|
78 |
+
globals()["pipeline"].to(device, dtype=torch.float32)
|
79 |
else:
|
80 |
+
if device == torch.device("cpu"):
|
81 |
+
dtype = torch.float32
|
82 |
+
else:
|
83 |
+
dtype = torch.float16
|
84 |
+
globals()["pipeline"] = InterpolationStableDiffusionXLPipeline.from_pretrained(
|
85 |
+
name_mapping[model_name], torch_dtype=dtype
|
86 |
)
|
87 |
+
globals()["pipeline"].to(device)
|
88 |
|
89 |
|
90 |
def save_image(img, index):
|
|
|
112 |
)
|
113 |
|
114 |
|
115 |
+
def get_example() -> list[list[str | float | int]]:
|
116 |
case = [
|
117 |
[
|
118 |
"A photo of dog, best quality, extremely detailed",
|
|
|
120 |
3,
|
121 |
6,
|
122 |
3,
|
123 |
+
"A car with dog furry texture, best quality, extremely detailed",
|
124 |
"monochrome, lowres, bad anatomy, worst quality, low quality",
|
125 |
"SD1.5-512",
|
126 |
6.1 / 50,
|
|
|
130 |
"self",
|
131 |
1002,
|
132 |
True,
|
133 |
+
],
|
134 |
+
[
|
135 |
+
"A photo of dog, best quality, extremely detailed",
|
136 |
+
"A photo of car, best quality, extremely detailed",
|
137 |
+
7,
|
138 |
+
8,
|
139 |
+
8,
|
140 |
+
"A toy named dog-car, best quality, extremely detailed",
|
141 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality",
|
142 |
+
"SD1.5-512",
|
143 |
+
8.1 / 50,
|
144 |
+
10,
|
145 |
+
50,
|
146 |
+
"fused_inner",
|
147 |
+
"self",
|
148 |
+
1002,
|
149 |
+
True,
|
150 |
+
],
|
151 |
+
[
|
152 |
+
"anime artwork a Pokemon called Pikachu sitting on the grass, dramatic, anime style, key visual, vibrant, studio anime, highly detailed",
|
153 |
+
"anime artwork a beautiful girl, dramatic, anime style, key visual, vibrant, studio anime, highly detailed",
|
154 |
+
7,
|
155 |
+
3,
|
156 |
+
3,
|
157 |
+
None,
|
158 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality",
|
159 |
+
"SDXL-1024",
|
160 |
+
25 / 50,
|
161 |
+
10,
|
162 |
+
50,
|
163 |
+
"fused_outer",
|
164 |
+
"self",
|
165 |
+
1002,
|
166 |
+
False,
|
167 |
+
],
|
168 |
]
|
169 |
return case
|
170 |
|
171 |
|
172 |
+
def change_generate_button_fn(enable: int) -> gr.Button:
|
173 |
+
if enable == 0:
|
174 |
+
return gr.Button(interactive=False, value="Switching Model...")
|
175 |
+
else:
|
176 |
+
return gr.Button(interactive=True, value="Generate")
|
177 |
+
|
178 |
+
|
179 |
def dynamic_gallery_fn(interpolation_size: int):
|
180 |
|
181 |
return gr.Gallery(
|
|
|
238 |
negative_prompt=negative_prompt,
|
239 |
guidance_scale=guidance_scale,
|
240 |
)
|
241 |
+
if hasattr(images, "images"):
|
242 |
+
# for sdxl
|
243 |
+
images = np.array(images.images)
|
244 |
if interpolation_size == 3:
|
245 |
final_images = images
|
246 |
break
|
|
|
255 |
|
256 |
interpolation_size = None
|
257 |
|
258 |
+
with gr.Blocks(css="style.css") as demo:
|
259 |
gr.Markdown(title)
|
260 |
gr.Markdown(description)
|
261 |
with gr.Group():
|
|
|
274 |
value="A photo of car, best quality, extremely detaile",
|
275 |
)
|
276 |
result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
|
277 |
+
generate_button = gr.Button(value="Generate", variant="primary")
|
278 |
with gr.Accordion("Advanced options", open=True):
|
279 |
with gr.Group():
|
280 |
with gr.Row():
|
|
|
291 |
label="alpha",
|
292 |
minimum=1,
|
293 |
maximum=50,
|
294 |
+
step=1,
|
295 |
value=6.0,
|
296 |
)
|
297 |
beta = gr.Slider(
|
298 |
label="beta",
|
299 |
minimum=1,
|
300 |
maximum=50,
|
301 |
+
step=1,
|
302 |
value=3.0,
|
303 |
)
|
304 |
gamma_plot = gr.LinePlot(
|
|
|
395 |
label="Model",
|
396 |
value="SD1.5-512",
|
397 |
interactive=True,
|
398 |
+
info="SDXL will run on float16 while the rest will run on float32.",
|
399 |
)
|
400 |
with gr.Column():
|
401 |
seed = gr.Slider(
|
|
|
431 |
seed,
|
432 |
same_latent,
|
433 |
],
|
|
|
|
|
434 |
cache_examples=CACHE_EXAMPLES,
|
435 |
)
|
436 |
|
|
|
443 |
interpolation_size.change(
|
444 |
fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
|
445 |
)
|
446 |
+
model_choice.change(
|
447 |
+
fn=change_generate_button_fn,
|
448 |
+
inputs=gr.Number(0, visible=False),
|
449 |
+
outputs=generate_button,
|
450 |
+
).then(fn=change_model_fn, inputs=model_choice).then(
|
451 |
+
fn=change_generate_button_fn,
|
452 |
+
inputs=gr.Number(1, visible=False),
|
453 |
+
outputs=generate_button,
|
454 |
+
)
|
455 |
inputs = [
|
456 |
prompt1,
|
457 |
prompt2,
|
|
|
479 |
)
|
480 |
gr.Markdown(article)
|
481 |
|
482 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
pipeline_interpolated_sdxl.py
CHANGED
@@ -403,6 +403,12 @@ class InterpolationStableDiffusionXLPipeline(
|
|
403 |
else:
|
404 |
self.watermark = None
|
405 |
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
def generate_latent(
|
407 |
self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
|
408 |
) -> torch.FloatTensor:
|
|
|
403 |
else:
|
404 |
self.watermark = None
|
405 |
|
406 |
+
def to(self, *args, **kwargs):
|
407 |
+
super().to(*args, **kwargs)
|
408 |
+
self.vae.to(*args, **kwargs)
|
409 |
+
self.text_encoder.to(*args, **kwargs)
|
410 |
+
self.unet.to(*args, **kwargs)
|
411 |
+
|
412 |
def generate_latent(
|
413 |
self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
|
414 |
) -> torch.FloatTensor:
|
pipeline_interpolated_stable_diffusion.py
CHANGED
@@ -286,7 +286,7 @@ class InterpolationStableDiffusionPipeline:
|
|
286 |
noise_pred = self.unet(
|
287 |
latent_model_input, t, encoder_hidden_states=embs
|
288 |
).sample
|
289 |
-
attn_proc =
|
290 |
self.unet.set_attn_processor(processor=attn_proc)
|
291 |
noise_uncond = self.unet(
|
292 |
latent_model_input, t, encoder_hidden_states=uncond_embs
|
@@ -477,7 +477,7 @@ class InterpolationStableDiffusionPipeline:
|
|
477 |
t=it,
|
478 |
is_fused=True,
|
479 |
)
|
480 |
-
self_attn_proc =
|
481 |
procs_dict = {
|
482 |
"pure_inner": pure_inner_attn_proc,
|
483 |
"fused_inner": fused_inner_attn_proc,
|
@@ -503,7 +503,7 @@ class InterpolationStableDiffusionPipeline:
|
|
503 |
noise_pred = self.unet(
|
504 |
latent_model_input, t, encoder_hidden_states=embs
|
505 |
).sample
|
506 |
-
attn_proc =
|
507 |
self.unet.set_attn_processor(processor=attn_proc)
|
508 |
noise_uncond = self.unet(
|
509 |
latent_model_input, t, encoder_hidden_states=uncond_embs
|
@@ -544,7 +544,7 @@ class InterpolationStableDiffusionPipeline:
|
|
544 |
Returns:
|
545 |
numpy.ndarray: The interpolated images.
|
546 |
"""
|
547 |
-
self.unet.set_attn_processor(processor=
|
548 |
start_emb = self.prompt_to_embedding(text_1)
|
549 |
end_emb = self.prompt_to_embedding(text_2)
|
550 |
neg_emb = self.prompt_to_embedding(negative_prompt)
|
|
|
286 |
noise_pred = self.unet(
|
287 |
latent_model_input, t, encoder_hidden_states=embs
|
288 |
).sample
|
289 |
+
attn_proc = AttnProcessor2_0()
|
290 |
self.unet.set_attn_processor(processor=attn_proc)
|
291 |
noise_uncond = self.unet(
|
292 |
latent_model_input, t, encoder_hidden_states=uncond_embs
|
|
|
477 |
t=it,
|
478 |
is_fused=True,
|
479 |
)
|
480 |
+
self_attn_proc = AttnProcessor2_0()
|
481 |
procs_dict = {
|
482 |
"pure_inner": pure_inner_attn_proc,
|
483 |
"fused_inner": fused_inner_attn_proc,
|
|
|
503 |
noise_pred = self.unet(
|
504 |
latent_model_input, t, encoder_hidden_states=embs
|
505 |
).sample
|
506 |
+
attn_proc = AttnProcessor2_0()
|
507 |
self.unet.set_attn_processor(processor=attn_proc)
|
508 |
noise_uncond = self.unet(
|
509 |
latent_model_input, t, encoder_hidden_states=uncond_embs
|
|
|
544 |
Returns:
|
545 |
numpy.ndarray: The interpolated images.
|
546 |
"""
|
547 |
+
self.unet.set_attn_processor(processor=AttnProcessor2_0())
|
548 |
start_emb = self.prompt_to_embedding(text_1)
|
549 |
end_emb = self.prompt_to_embedding(text_2)
|
550 |
neg_emb = self.prompt_to_embedding(negative_prompt)
|