Spaces:
Sleeping
Sleeping
Tony Lian
commited on
Commit
•
89f6983
1
Parent(s):
e32648c
Update: add attention guidance and refactor the code
Browse files- app.py +77 -149
- examples.py +56 -6
- generation.py +412 -130
- models/modeling_utils.py +0 -874
- models/pipelines.py +352 -2
- models/sam.py +4 -2
- utils/attn.py +140 -0
- utils/boxdiff.py +259 -0
- utils/guidance.py +358 -0
- utils/latents.py +3 -2
- utils/parse.py +93 -18
- utils/utils.py +0 -1
- utils/vis.py +153 -0
app.py
CHANGED
@@ -1,65 +1,27 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
-
import
|
4 |
-
from matplotlib.patches import Polygon
|
5 |
-
from matplotlib.collections import PatchCollection
|
6 |
import matplotlib.pyplot as plt
|
7 |
-
from utils.parse import filter_boxes
|
8 |
from generation import run as run_ours
|
9 |
from baseline import run as run_baseline
|
10 |
import torch
|
11 |
from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
|
12 |
-
from examples import stage1_examples, stage2_examples
|
13 |
|
14 |
-
|
15 |
-
if torch.cuda.is_available():
|
16 |
-
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
17 |
-
|
18 |
-
box_scale = (512, 512)
|
19 |
-
size = box_scale
|
20 |
-
|
21 |
-
bg_prompt_text = "Background prompt: "
|
22 |
-
|
23 |
-
default_template = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Make the boxes larger if possible. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Generate the object descriptions and background prompts in English even if the caption might not be in English. Do not include non-existing or excluded objects in the background prompt. Please refer to the example below for the desired format.
|
24 |
-
|
25 |
-
Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
|
26 |
-
Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
|
27 |
-
Background prompt: A realistic image of a landscape scene
|
28 |
-
|
29 |
-
Caption: A watercolor painting of a wooden table in the living room with an apple on it
|
30 |
-
Objects: [('a wooden table', [65, 243, 344, 206]), ('a apple', [206, 306, 81, 69])]
|
31 |
-
Background prompt: A watercolor painting of a living room
|
32 |
-
|
33 |
-
Caption: A watercolor painting of two pandas eating bamboo in a forest
|
34 |
-
Objects: [('a panda eating bambooo', [30, 171, 212, 226]), ('a panda eating bambooo', [264, 173, 222, 221])]
|
35 |
-
Background prompt: A watercolor painting of a forest
|
36 |
-
|
37 |
-
Caption: A realistic image of four skiers standing in a line on the snow near a palm tree
|
38 |
-
Objects: [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 180, 103, 180])]
|
39 |
-
Background prompt: A realistic image of an outdoor scene with snow
|
40 |
|
41 |
-
|
42 |
-
Objects: [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]
|
43 |
-
Background prompt: An oil painting of the sea
|
44 |
-
|
45 |
-
Caption: A realistic image of a cat playing with a dog in a park with flowers
|
46 |
-
Objects: [('a playful cat', [51, 67, 271, 324]), ('a playful dog', [302, 119, 211, 228])]
|
47 |
-
Background prompt: A realistic image of a park with flowers
|
48 |
-
|
49 |
-
Caption: 一个客厅场景的油画,墙上挂着电视,电视下面是一个柜子,柜子上有一个花瓶。
|
50 |
-
Objects: [('a tv', [88, 85, 335, 203]), ('a cabinet', [57, 308, 404, 201]), ('a flower vase', [166, 222, 92, 108])]
|
51 |
-
Background prompt: An oil painting of a living room scene"""
|
52 |
-
|
53 |
-
simplified_prompt = """{template}
|
54 |
-
|
55 |
-
Caption: {prompt}
|
56 |
-
Objects: """
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
Background prompt: A realistic photo of a grassy area."""
|
63 |
|
64 |
def get_lmd_prompt(prompt, template=default_template):
|
65 |
if prompt == "":
|
@@ -71,10 +33,10 @@ def get_lmd_prompt(prompt, template=default_template):
|
|
71 |
def get_layout_image(response):
|
72 |
if response == "":
|
73 |
response = layout_placeholder
|
74 |
-
gen_boxes, bg_prompt =
|
75 |
fig = plt.figure(figsize=(8, 8))
|
76 |
# https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
|
77 |
-
show_boxes(gen_boxes, bg_prompt)
|
78 |
# If we haven't already shown or saved the plot, then we need to
|
79 |
# draw the figure first...
|
80 |
fig.canvas.draw()
|
@@ -88,32 +50,41 @@ def get_layout_image(response):
|
|
88 |
def get_layout_image_gallery(response):
|
89 |
return [get_layout_image(response)]
|
90 |
|
91 |
-
def get_ours_image(response, overall_prompt_override="", seed=0, num_inference_steps=
|
92 |
if response == "":
|
93 |
response = layout_placeholder
|
94 |
-
gen_boxes, bg_prompt =
|
95 |
gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
|
96 |
spec = {
|
97 |
# prompt is unused
|
98 |
'prompt': '',
|
99 |
'gen_boxes': gen_boxes,
|
100 |
-
'bg_prompt': bg_prompt
|
|
|
101 |
}
|
102 |
|
103 |
if dpm_scheduler:
|
104 |
scheduler_key = "dpm_scheduler"
|
105 |
else:
|
106 |
scheduler_key = "scheduler"
|
107 |
-
|
|
|
|
|
108 |
image_np, so_img_list = run_ours(
|
109 |
spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
|
110 |
fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
|
111 |
-
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
|
112 |
-
so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt,
|
|
|
113 |
)
|
114 |
images = [image_np]
|
115 |
if show_so_imgs:
|
116 |
images.extend([np.asarray(so_img) for so_img in so_img_list])
|
|
|
|
|
|
|
|
|
|
|
117 |
return images
|
118 |
|
119 |
def get_baseline_image(prompt, seed=0):
|
@@ -126,73 +97,6 @@ def get_baseline_image(prompt, seed=0):
|
|
126 |
image_np = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
|
127 |
return [image_np]
|
128 |
|
129 |
-
def parse_input(text=None):
|
130 |
-
try:
|
131 |
-
if "Objects: " in text:
|
132 |
-
text = text.split("Objects: ")[1]
|
133 |
-
|
134 |
-
text_split = text.split(bg_prompt_text)
|
135 |
-
if len(text_split) == 2:
|
136 |
-
gen_boxes, bg_prompt = text_split
|
137 |
-
gen_boxes = ast.literal_eval(gen_boxes)
|
138 |
-
bg_prompt = bg_prompt.strip()
|
139 |
-
except Exception as e:
|
140 |
-
raise gr.Error(f"response format invalid: {e} (text: {text})")
|
141 |
-
|
142 |
-
return gen_boxes, bg_prompt
|
143 |
-
|
144 |
-
def draw_boxes(anns):
|
145 |
-
ax = plt.gca()
|
146 |
-
ax.set_autoscale_on(False)
|
147 |
-
polygons = []
|
148 |
-
color = []
|
149 |
-
for ann in anns:
|
150 |
-
c = (np.random.random((1, 3))*0.6+0.4)
|
151 |
-
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
152 |
-
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
|
153 |
-
[bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
154 |
-
np_poly = np.array(poly).reshape((4, 2))
|
155 |
-
polygons.append(Polygon(np_poly))
|
156 |
-
color.append(c)
|
157 |
-
|
158 |
-
# print(ann)
|
159 |
-
name = ann['name'] if 'name' in ann else str(ann['category_id'])
|
160 |
-
ax.text(bbox_x, bbox_y, name, style='italic',
|
161 |
-
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
162 |
-
|
163 |
-
p = PatchCollection(polygons, facecolor='none',
|
164 |
-
edgecolors=color, linewidths=2)
|
165 |
-
ax.add_collection(p)
|
166 |
-
|
167 |
-
|
168 |
-
def show_boxes(gen_boxes, bg_prompt=None):
|
169 |
-
anns = [{'name': gen_box[0], 'bbox': gen_box[1]}
|
170 |
-
for gen_box in gen_boxes]
|
171 |
-
|
172 |
-
# White background (to allow line to show on the edge)
|
173 |
-
I = np.ones((size[0]+4, size[1]+4, 3), dtype=np.uint8) * 255
|
174 |
-
|
175 |
-
plt.imshow(I)
|
176 |
-
plt.axis('off')
|
177 |
-
|
178 |
-
if bg_prompt is not None:
|
179 |
-
ax = plt.gca()
|
180 |
-
ax.text(0, 0, bg_prompt, style='italic',
|
181 |
-
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
182 |
-
|
183 |
-
c = np.zeros((1, 3))
|
184 |
-
[bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0])
|
185 |
-
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
|
186 |
-
[bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
187 |
-
np_poly = np.array(poly).reshape((4, 2))
|
188 |
-
polygons = [Polygon(np_poly)]
|
189 |
-
color = [c]
|
190 |
-
p = PatchCollection(polygons, facecolor='none',
|
191 |
-
edgecolors=color, linewidths=2)
|
192 |
-
ax.add_collection(p)
|
193 |
-
|
194 |
-
draw_boxes(anns)
|
195 |
-
|
196 |
duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>'
|
197 |
|
198 |
html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
|
@@ -200,15 +104,28 @@ html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to
|
|
200 |
<h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
|
201 |
<p><b>Tips:</b><p>
|
202 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
203 |
-
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the
|
204 |
-
<p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
|
205 |
<p>4. The diffusion model only runs 20 steps by default in this demo. You can make it run more steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
|
206 |
<p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (<b>Currently we are using a T4 GPU on this space, which is quite slow, and you can add a A10G to make it 5x faster</b>) {duplicate_html}</p>
|
207 |
<br/>
|
208 |
-
<p>Implementation note: In this demo, we
|
209 |
-
<style>.btn {{flex-grow: unset !important;}} </
|
210 |
"""
|
211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
with gr.Blocks(
|
213 |
title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
|
214 |
) as g:
|
@@ -230,42 +147,53 @@ with gr.Blocks(
|
|
230 |
inputs=[prompt],
|
231 |
outputs=[output],
|
232 |
fn=get_lmd_prompt,
|
233 |
-
cache_examples=
|
|
|
234 |
)
|
235 |
|
236 |
with gr.Tab("Stage 2 (New). Layout to Image generation"):
|
237 |
with gr.Row():
|
238 |
with gr.Column(scale=1):
|
239 |
-
|
240 |
-
|
241 |
-
num_inference_steps = gr.Slider(1, 250, value=
|
|
|
|
|
242 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
243 |
with gr.Accordion("Advanced options (play around for better generation)", open=False):
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
|
|
|
|
254 |
visualize_btn = gr.Button("Visualize Layout", elem_classes="btn")
|
255 |
generate_btn = gr.Button("Generate Image from Layout", variant='primary', elem_classes="btn")
|
256 |
with gr.Column(scale=1):
|
257 |
gallery = gr.Gallery(
|
258 |
label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain", preview=True
|
259 |
)
|
|
|
|
|
260 |
visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
|
261 |
-
generate_btn.click(fn=get_ours_image, inputs=[response, overall_prompt_override, seed, num_inference_steps, dpm_scheduler, use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")
|
262 |
|
263 |
gr.Examples(
|
264 |
examples=stage2_examples,
|
265 |
inputs=[response, overall_prompt_override, seed],
|
266 |
outputs=[gallery],
|
267 |
fn=get_ours_image,
|
268 |
-
cache_examples=
|
|
|
269 |
)
|
270 |
|
271 |
with gr.Tab("Baseline: Stable Diffusion"):
|
@@ -274,8 +202,7 @@ with gr.Blocks(
|
|
274 |
sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
|
275 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
276 |
generate_btn = gr.Button("Generate", elem_classes="btn")
|
277 |
-
|
278 |
-
# output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
|
279 |
with gr.Column(scale=1):
|
280 |
gallery = gr.Gallery(
|
281 |
label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
|
@@ -287,7 +214,8 @@ with gr.Blocks(
|
|
287 |
inputs=[sd_prompt],
|
288 |
outputs=[gallery],
|
289 |
fn=get_baseline_image,
|
290 |
-
cache_examples=
|
|
|
291 |
)
|
292 |
|
293 |
g.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
import os
|
|
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
from utils.parse import filter_boxes, parse_input_with_negative, show_boxes
|
6 |
from generation import run as run_ours
|
7 |
from baseline import run as run_baseline
|
8 |
import torch
|
9 |
from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
|
10 |
+
from examples import stage1_examples, stage2_examples, default_template, simplified_prompt, prompt_placeholder, layout_placeholder
|
11 |
|
12 |
+
cuda_available = torch.cuda.is_available()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
if cuda_available:
|
17 |
+
gpu_memory = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
|
18 |
+
low_memory = gpu_memory <= 16 * 1024 ** 3
|
19 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}. With GPU memory: {gpu_memory}. Low memory: {low_memory}")
|
20 |
+
else:
|
21 |
+
low_memory = False
|
22 |
|
23 |
+
cache_examples = True
|
24 |
+
default_num_inference_steps = 20 if low_memory else 50
|
|
|
25 |
|
26 |
def get_lmd_prompt(prompt, template=default_template):
|
27 |
if prompt == "":
|
|
|
33 |
def get_layout_image(response):
|
34 |
if response == "":
|
35 |
response = layout_placeholder
|
36 |
+
gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)
|
37 |
fig = plt.figure(figsize=(8, 8))
|
38 |
# https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
|
39 |
+
show_boxes(gen_boxes, bg_prompt, neg_prompt)
|
40 |
# If we haven't already shown or saved the plot, then we need to
|
41 |
# draw the figure first...
|
42 |
fig.canvas.draw()
|
|
|
50 |
def get_layout_image_gallery(response):
|
51 |
return [get_layout_image(response)]
|
52 |
|
53 |
+
def get_ours_image(response, overall_prompt_override="", seed=0, num_inference_steps=250, dpm_scheduler=True, use_autocast=False, fg_seed_start=20, fg_blending_ratio=0.1, frozen_step_ratio=0.5, attn_guidance_step_ratio=0.6, gligen_scheduled_sampling_beta=0.4, attn_guidance_scale=20, use_ref_ca=True, so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT, show_so_imgs=False, scale_boxes=False):
|
54 |
if response == "":
|
55 |
response = layout_placeholder
|
56 |
+
gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)
|
57 |
gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
|
58 |
spec = {
|
59 |
# prompt is unused
|
60 |
'prompt': '',
|
61 |
'gen_boxes': gen_boxes,
|
62 |
+
'bg_prompt': bg_prompt,
|
63 |
+
'extra_neg_prompt': neg_prompt
|
64 |
}
|
65 |
|
66 |
if dpm_scheduler:
|
67 |
scheduler_key = "dpm_scheduler"
|
68 |
else:
|
69 |
scheduler_key = "scheduler"
|
70 |
+
|
71 |
+
overall_max_index_step = int(attn_guidance_step_ratio * num_inference_steps)
|
72 |
+
|
73 |
image_np, so_img_list = run_ours(
|
74 |
spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
|
75 |
fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
|
76 |
+
so_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, overall_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
|
77 |
+
use_ref_ca=use_ref_ca, so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt,
|
78 |
+
loss_scale=attn_guidance_scale, max_index_step=0, overall_loss_scale=attn_guidance_scale, overall_max_index_step=overall_max_index_step,
|
79 |
)
|
80 |
images = [image_np]
|
81 |
if show_so_imgs:
|
82 |
images.extend([np.asarray(so_img) for so_img in so_img_list])
|
83 |
+
|
84 |
+
if cuda_available:
|
85 |
+
print(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
|
86 |
+
torch.cuda.reset_max_memory_allocated()
|
87 |
+
|
88 |
return images
|
89 |
|
90 |
def get_baseline_image(prompt, seed=0):
|
|
|
97 |
image_np = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
|
98 |
return [image_np]
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>'
|
101 |
|
102 |
html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
|
|
|
104 |
<h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
|
105 |
<p><b>Tips:</b><p>
|
106 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
107 |
+
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the objects bigger or move the objects).</p>
|
108 |
+
<p>3. You can also try prompts in Simplified Chinese. You need to leave "prompt for overall image" empty in this case. If you want to try prompts in another language, translate the first line of last example to your language.</p>
|
109 |
<p>4. The diffusion model only runs 20 steps by default in this demo. You can make it run more steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
|
110 |
<p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (<b>Currently we are using a T4 GPU on this space, which is quite slow, and you can add a A10G to make it 5x faster</b>) {duplicate_html}</p>
|
111 |
<br/>
|
112 |
+
<p>Implementation note (updated): In this demo, we provide a few modes: faster generation by disabling attention/per-box guidance. The standard version describes what is implemented for the paper. You can set GLIGEN guidance steps ratio to 0 to disable GLIGEN and use only the original SD weights.</p>
|
113 |
+
<style>.btn {{flex-grow: unset !important;}} </p>
|
114 |
"""
|
115 |
|
116 |
+
def preset_change(preset):
|
117 |
+
# frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt
|
118 |
+
if preset == "Standard":
|
119 |
+
return gr.update(value=0.5, interactive=True), gr.update(value=0.6, interactive=True), gr.update(interactive=True), gr.update(value=True, interactive=True), gr.update(interactive=True)
|
120 |
+
elif preset == "Faster (disable attention guidance)":
|
121 |
+
return gr.update(value=0.5, interactive=True), gr.update(value=0, interactive=False), gr.update(interactive=False), gr.update(value=True, interactive=True), gr.update(interactive=True)
|
122 |
+
elif preset == "Faster (disable per-box guidance)":
|
123 |
+
return gr.update(value=0, interactive=False), gr.update(value=0.6, interactive=True), gr.update(interactive=True), gr.update(value=False, interactive=False), gr.update(interactive=False)
|
124 |
+
elif preset == "Fastest (disable both)":
|
125 |
+
return gr.update(value=0, interactive=False), gr.update(value=0, interactive=False), gr.update(interactive=False), gr.update(value=False, interactive=False), gr.update(interactive=True)
|
126 |
+
else:
|
127 |
+
raise gr.Error(f"Unknown preset {preset}")
|
128 |
+
|
129 |
with gr.Blocks(
|
130 |
title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
|
131 |
) as g:
|
|
|
147 |
inputs=[prompt],
|
148 |
outputs=[output],
|
149 |
fn=get_lmd_prompt,
|
150 |
+
cache_examples=cache_examples,
|
151 |
+
label="example_stage1"
|
152 |
)
|
153 |
|
154 |
with gr.Tab("Stage 2 (New). Layout to Image generation"):
|
155 |
with gr.Row():
|
156 |
with gr.Column(scale=1):
|
157 |
+
overall_prompt_override = gr.Textbox(lines=2, label="Prompt for the overall image (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
|
158 |
+
response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed here)", placeholder=layout_placeholder)
|
159 |
+
num_inference_steps = gr.Slider(1, 100 if low_memory else 250, value=default_num_inference_steps, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
|
160 |
+
# Using a environment variable allows setting default to faster/fastest on low-end GPUs.
|
161 |
+
preset = gr.Radio(label="Guidance: apply less control for faster generation", choices=["Standard", "Faster (disable attention guidance)", "Faster (disable per-box guidance)", "Fastest (disable both)"], value="Faster (disable attention guidance)" if low_memory else "Standard")
|
162 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
163 |
with gr.Accordion("Advanced options (play around for better generation)", open=False):
|
164 |
+
with gr.Tab("Guidance"):
|
165 |
+
frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: stronger attribute binding; lower: higher coherence")
|
166 |
+
gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value, higher: stronger GLIGEN guidance)")
|
167 |
+
attn_guidance_step_ratio = gr.Slider(0, 1, value=0.6, step=0.01, label="Attention guidance steps ratio (higher: stronger attention guidance; lower: faster and higher coherence")
|
168 |
+
attn_guidance_scale = gr.Slider(0, 50, value=20, step=0.5, label="Attention guidance scale: 0 means no attention guidance.")
|
169 |
+
use_ref_ca = gr.Checkbox(label="Using per-box attention to guide reference attention", show_label=False, value=True)
|
170 |
+
with gr.Tab("Generation"):
|
171 |
+
dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
|
172 |
+
use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)" + " [enabled due to low GPU memory]" if low_memory else "", show_label=False, value=True, interactive=not low_memory)
|
173 |
+
fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
|
174 |
+
fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
|
175 |
+
scale_boxes = gr.Checkbox(label="Scale bounding boxes to just fit the scene", show_label=False, value=False)
|
176 |
+
so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
|
177 |
+
overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
|
178 |
+
show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False, value=False)
|
179 |
visualize_btn = gr.Button("Visualize Layout", elem_classes="btn")
|
180 |
generate_btn = gr.Button("Generate Image from Layout", variant='primary', elem_classes="btn")
|
181 |
with gr.Column(scale=1):
|
182 |
gallery = gr.Gallery(
|
183 |
label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain", preview=True
|
184 |
)
|
185 |
+
preset.change(preset_change, [preset], [frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt])
|
186 |
+
prompt.change(None, [prompt], overall_prompt_override, _js="(x) => x")
|
187 |
visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
|
188 |
+
generate_btn.click(fn=get_ours_image, inputs=[response, overall_prompt_override, seed, num_inference_steps, dpm_scheduler, use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio, attn_guidance_step_ratio, gligen_scheduled_sampling_beta, attn_guidance_scale, use_ref_ca, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")
|
189 |
|
190 |
gr.Examples(
|
191 |
examples=stage2_examples,
|
192 |
inputs=[response, overall_prompt_override, seed],
|
193 |
outputs=[gallery],
|
194 |
fn=get_ours_image,
|
195 |
+
cache_examples=cache_examples,
|
196 |
+
label="example_ours"
|
197 |
)
|
198 |
|
199 |
with gr.Tab("Baseline: Stable Diffusion"):
|
|
|
202 |
sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
|
203 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
204 |
generate_btn = gr.Button("Generate", elem_classes="btn")
|
205 |
+
|
|
|
206 |
with gr.Column(scale=1):
|
207 |
gallery = gr.Gallery(
|
208 |
label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
|
|
|
214 |
inputs=[sd_prompt],
|
215 |
outputs=[gallery],
|
216 |
fn=get_baseline_image,
|
217 |
+
cache_examples=cache_examples,
|
218 |
+
label="example_sd"
|
219 |
)
|
220 |
|
221 |
g.launch()
|
examples.py
CHANGED
@@ -1,3 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
stage1_examples = [
|
2 |
["""A realistic photo of a wooden table with an apple on the left and a pear on the right."""],
|
3 |
["""A realistic photo of 4 TVs on a wall."""],
|
@@ -10,25 +52,33 @@ stage1_examples = [
|
|
10 |
|
11 |
# Layout, seed
|
12 |
stage2_examples = [
|
13 |
-
["""Caption: A realistic
|
14 |
Objects: [('a wooden table', [30, 30, 452, 452]), ('an apple', [52, 223, 50, 60]), ('a pear', [400, 240, 50, 60])]
|
15 |
-
Background prompt: A realistic
|
16 |
["""Caption: A realistic photo of 4 TVs on a wall.
|
17 |
Objects: [('a TV', [12, 108, 120, 100]), ('a TV', [132, 112, 120, 100]), ('a TV', [252, 104, 120, 100]), ('a TV', [372, 106, 120, 100])]
|
18 |
-
Background prompt: A realistic photo of a wall""", "", 0],
|
19 |
["""Caption: A realistic photo of a gray cat and an orange dog on the grass.
|
20 |
Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
|
21 |
-
Background prompt: A realistic photo of a grassy area.""", "", 0],
|
22 |
["""Caption: 一个室内场景的水彩画,一个桌子上面放着一盘水果
|
23 |
Objects: [('a table', [81, 242, 350, 210]), ('a plate of fruits', [151, 287, 210, 117])]
|
24 |
Background prompt: A watercolor painting of an indoor scene""", "", 1],
|
25 |
["""Caption: In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.
|
26 |
Objects: [('a blue cube', [232, 116, 76, 76]), ('a red cube', [232, 212, 76, 76]), ('a vase', [100, 198, 62, 144])]
|
27 |
-
Background prompt: An empty indoor scene""", "", 2],
|
28 |
["""Caption: A realistic photo of a wooden table without bananas in an indoor scene
|
29 |
Objects: [('a wooden table', [75, 256, 365, 156])]
|
30 |
-
Background prompt: A realistic photo of an indoor scene
|
|
|
31 |
["""Caption: A realistic photo of two cars on the road.
|
32 |
Objects: [('a car', [20, 242, 235, 185]), ('a car', [275, 246, 215, 180])]
|
33 |
Background prompt: A realistic photo of a road.""", "A realistic photo of two cars on the road.", 4],
|
34 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_template = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinnate [512, 512]. The bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object (i.e., start the object name with "a" or "an" if possible). Do not put objects that are already provided in the bounding boxes into the background prompt. Do not include non-existing or excluded objects in the background prompt. If needed, you can make reasonable guesses. Please refer to the example below for the desired format.
|
2 |
+
|
3 |
+
Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
|
4 |
+
Objects: [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
|
5 |
+
Background prompt: A realistic landscape scene
|
6 |
+
Negative prompt:
|
7 |
+
|
8 |
+
Caption: A realistic top-down view of a wooden table with two apples on it
|
9 |
+
Objects: [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]
|
10 |
+
Background prompt: A realistic top-down view
|
11 |
+
Negative prompt:
|
12 |
+
|
13 |
+
Caption: A realistic scene of three skiers standing in a line on the snow near a palm tree
|
14 |
+
Objects: [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]
|
15 |
+
Background prompt: A realistic outdoor scene with snow
|
16 |
+
Negative prompt:
|
17 |
+
|
18 |
+
Caption: An oil painting of a pink dolphin jumping on the left of a steam boat on the sea
|
19 |
+
Objects: [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]
|
20 |
+
Background prompt: An oil painting of the sea
|
21 |
+
Negative prompt:
|
22 |
+
|
23 |
+
Caption: A cute cat and an angry dog without birds
|
24 |
+
Objects: [('a cute cat', [51, 67, 271, 324]), ('an angry dog', [302, 119, 211, 228])]
|
25 |
+
Background prompt: A realistic scene
|
26 |
+
Negative prompt: birds
|
27 |
+
|
28 |
+
Caption: Two pandas in a forest without flowers
|
29 |
+
Objects: [('a panda', [30, 171, 212, 226]), ('a panda', [264, 173, 222, 221])]
|
30 |
+
Background prompt: A forest
|
31 |
+
Negative prompt: flowers
|
32 |
+
|
33 |
+
Caption: 一个客厅场景的油画,墙上挂着一幅画,电视下面是一个柜子,柜子上有一个花瓶,画里没有椅子。
|
34 |
+
Objects: [('a painting', [88, 85, 335, 203]), ('a cabinet', [57, 308, 404, 201]), ('a flower vase', [166, 222, 92, 108]), ('a flower vase', [328, 222, 92, 108])]
|
35 |
+
Background prompt: An oil painting of a living room scene
|
36 |
+
Negative prompt: chairs"""
|
37 |
+
|
38 |
+
simplified_prompt = """{template}
|
39 |
+
|
40 |
+
Caption: {prompt}
|
41 |
+
Objects: """
|
42 |
+
|
43 |
stage1_examples = [
|
44 |
["""A realistic photo of a wooden table with an apple on the left and a pear on the right."""],
|
45 |
["""A realistic photo of 4 TVs on a wall."""],
|
|
|
52 |
|
53 |
# Layout, seed
|
54 |
stage2_examples = [
|
55 |
+
["""Caption: A realistic top-down view of a wooden table with an apple on the left and a pear on the right.
|
56 |
Objects: [('a wooden table', [30, 30, 452, 452]), ('an apple', [52, 223, 50, 60]), ('a pear', [400, 240, 50, 60])]
|
57 |
+
Background prompt: A realistic top-down view of a room""", "A realistic top-down view of a wooden table with an apple on the left and a pear on the right.", 0],
|
58 |
["""Caption: A realistic photo of 4 TVs on a wall.
|
59 |
Objects: [('a TV', [12, 108, 120, 100]), ('a TV', [132, 112, 120, 100]), ('a TV', [252, 104, 120, 100]), ('a TV', [372, 106, 120, 100])]
|
60 |
+
Background prompt: A realistic photo of a wall""", "A realistic photo of 4 TVs on a wall.", 0],
|
61 |
["""Caption: A realistic photo of a gray cat and an orange dog on the grass.
|
62 |
Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
|
63 |
+
Background prompt: A realistic photo of a grassy area.""", "A realistic photo of a gray cat and an orange dog on the grass.", 0],
|
64 |
["""Caption: 一个室内场景的水彩画,一个桌子上面放着一盘水果
|
65 |
Objects: [('a table', [81, 242, 350, 210]), ('a plate of fruits', [151, 287, 210, 117])]
|
66 |
Background prompt: A watercolor painting of an indoor scene""", "", 1],
|
67 |
["""Caption: In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.
|
68 |
Objects: [('a blue cube', [232, 116, 76, 76]), ('a red cube', [232, 212, 76, 76]), ('a vase', [100, 198, 62, 144])]
|
69 |
+
Background prompt: An empty indoor scene""", "In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.", 2],
|
70 |
["""Caption: A realistic photo of a wooden table without bananas in an indoor scene
|
71 |
Objects: [('a wooden table', [75, 256, 365, 156])]
|
72 |
+
Background prompt: A realistic photo of an indoor scene
|
73 |
+
Negative prompt: bananas""", "A realistic photo of a wooden table without bananas in an indoor scene", 3],
|
74 |
["""Caption: A realistic photo of two cars on the road.
|
75 |
Objects: [('a car', [20, 242, 235, 185]), ('a car', [275, 246, 215, 180])]
|
76 |
Background prompt: A realistic photo of a road.""", "A realistic photo of two cars on the road.", 4],
|
77 |
]
|
78 |
+
|
79 |
+
|
80 |
+
prompt_placeholder = "A realistic photo of a gray cat and an orange dog on the grass."
|
81 |
+
|
82 |
+
layout_placeholder = """Caption: A realistic photo of a gray cat and an orange dog on the grass.
|
83 |
+
Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
|
84 |
+
Background prompt: A realistic photo of a grassy area."""
|
generation.py
CHANGED
@@ -1,19 +1,24 @@
|
|
1 |
-
version = "v3.0"
|
2 |
-
|
3 |
import torch
|
4 |
-
import numpy as np
|
5 |
import models
|
6 |
import utils
|
7 |
from models import pipelines, sam
|
8 |
-
from utils import parse, latents
|
9 |
-
from shared import
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
verbose = False
|
13 |
-
# Accelerates per-box generation
|
14 |
-
use_fast_schedule = True
|
15 |
|
16 |
-
vae, tokenizer, text_encoder, unet, dtype =
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
model_dict.update(sam_model_dict)
|
19 |
|
@@ -21,195 +26,472 @@ model_dict.update(sam_model_dict)
|
|
21 |
# Hyperparams
|
22 |
height = 512 # default height of Stable Diffusion
|
23 |
width = 512 # default width of Stable Diffusion
|
24 |
-
H, W = height // 8, width // 8
|
25 |
guidance_scale = 7.5 # Scale for classifier-free guidance
|
26 |
|
27 |
# batch size that is not 1 is not supported
|
28 |
overall_batch_size = 1
|
29 |
|
|
|
|
|
|
|
30 |
# discourage masks with confidence below
|
31 |
discourage_mask_below_confidence = 0.85
|
32 |
|
33 |
# discourage masks with iou (with coarse binarized attention mask) below
|
34 |
discourage_mask_below_coarse_iou = 0.25
|
35 |
|
|
|
|
|
|
|
36 |
run_ind = None
|
37 |
|
38 |
|
39 |
-
def
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
mask_selected_tensor = torch.tensor(mask_selected)
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
if not so_prompt_phrase_word_box_list:
|
104 |
-
return latents_all_list, mask_tensor_list
|
105 |
-
|
106 |
-
prompts, bboxes, phrases, words = [], [], [], []
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
-
return latents_all_list, mask_tensor_list, so_img_list
|
117 |
|
118 |
|
119 |
# Note: need to keep the supervision, especially the box corrdinates, corresponds to each other in single object and overall.
|
120 |
|
|
|
121 |
def run(
|
122 |
-
spec,
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
):
|
126 |
-
"""
|
127 |
so_center_box: using centered box in single object generation
|
128 |
so_horizontal_center_only: move to the center horizontally only
|
129 |
-
|
130 |
align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
|
131 |
horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
|
132 |
"""
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
frozen_step_ratio = min(max(frozen_step_ratio, 0.), 1.)
|
137 |
frozen_steps = int(num_inference_steps * frozen_step_ratio)
|
138 |
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
if overall_prompt_override and overall_prompt_override.strip():
|
143 |
overall_prompt = overall_prompt_override.strip()
|
144 |
|
145 |
-
overall_phrases, overall_words, overall_bboxes =
|
|
|
|
|
|
|
|
|
146 |
|
147 |
# The so box is centered but the overall boxes are not (since we need to place to the right place).
|
148 |
if so_center_box:
|
149 |
-
so_prompt_phrase_word_box_list = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
if verbose:
|
151 |
-
print(
|
|
|
|
|
152 |
so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
sam_refine_kwargs = dict(
|
155 |
-
discourage_mask_below_confidence=discourage_mask_below_confidence,
|
156 |
-
|
|
|
|
|
|
|
|
|
157 |
)
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
159 |
# Note that so and overall use different negative prompts
|
160 |
|
161 |
with torch.autocast("cuda", enabled=use_autocast):
|
162 |
so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
|
163 |
if so_prompts:
|
164 |
-
so_input_embeddings = models.encode_prompts(
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
else:
|
166 |
so_input_embeddings = []
|
167 |
|
168 |
-
overall_input_embeddings = models.encode_prompts(prompts=[overall_prompt], tokenizer=tokenizer, negative_prompt=overall_negative_prompt, text_encoder=text_encoder)
|
169 |
-
|
170 |
input_latents_list, latents_bg = latents.get_input_latents_list(
|
171 |
-
model_dict,
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
)
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
)
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
)
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
overall_bboxes_flattened, overall_phrases_flattened = [], []
|
189 |
for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
|
190 |
for overall_bbox in overall_bboxes_item:
|
191 |
overall_bboxes_flattened.append(overall_bbox)
|
192 |
overall_phrases_flattened.append(overall_phrase)
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
# Generate with composed latents
|
195 |
|
196 |
# Foreground should be frozen
|
197 |
frozen_mask = foreground_indices != 0
|
198 |
-
|
199 |
-
|
200 |
-
model_dict,
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
)
|
205 |
|
206 |
-
print(
|
|
|
|
|
207 |
print("Generation from composed latents (with semantic guidance)")
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
gc.collect()
|
212 |
-
torch.cuda.empty_cache()
|
213 |
-
|
214 |
-
return images[0], so_img_list
|
215 |
|
|
|
|
|
|
|
|
1 |
import torch
|
|
|
2 |
import models
|
3 |
import utils
|
4 |
from models import pipelines, sam
|
5 |
+
from utils import parse, guidance, attn, latents, vis
|
6 |
+
from shared import (
|
7 |
+
model_dict,
|
8 |
+
sam_model_dict,
|
9 |
+
DEFAULT_SO_NEGATIVE_PROMPT,
|
10 |
+
DEFAULT_OVERALL_NEGATIVE_PROMPT,
|
11 |
+
)
|
12 |
|
13 |
verbose = False
|
|
|
|
|
14 |
|
15 |
+
vae, tokenizer, text_encoder, unet, dtype = (
|
16 |
+
model_dict.vae,
|
17 |
+
model_dict.tokenizer,
|
18 |
+
model_dict.text_encoder,
|
19 |
+
model_dict.unet,
|
20 |
+
model_dict.dtype,
|
21 |
+
)
|
22 |
|
23 |
model_dict.update(sam_model_dict)
|
24 |
|
|
|
26 |
# Hyperparams
|
27 |
height = 512 # default height of Stable Diffusion
|
28 |
width = 512 # default width of Stable Diffusion
|
29 |
+
H, W = height // 8, width // 8 # size of the latent
|
30 |
guidance_scale = 7.5 # Scale for classifier-free guidance
|
31 |
|
32 |
# batch size that is not 1 is not supported
|
33 |
overall_batch_size = 1
|
34 |
|
35 |
+
# semantic guidance kwargs (single object)
|
36 |
+
guidance_attn_keys = pipelines.DEFAULT_GUIDANCE_ATTN_KEYS
|
37 |
+
|
38 |
# discourage masks with confidence below
|
39 |
discourage_mask_below_confidence = 0.85
|
40 |
|
41 |
# discourage masks with iou (with coarse binarized attention mask) below
|
42 |
discourage_mask_below_coarse_iou = 0.25
|
43 |
|
44 |
+
# This is controls the foreground variations
|
45 |
+
fg_blending_ratio = 0.1
|
46 |
+
|
47 |
run_ind = None
|
48 |
|
49 |
|
50 |
+
def generate_single_object_with_box(
|
51 |
+
prompt,
|
52 |
+
box,
|
53 |
+
phrase,
|
54 |
+
word,
|
55 |
+
input_latents,
|
56 |
+
input_embeddings,
|
57 |
+
semantic_guidance_kwargs,
|
58 |
+
obj_attn_key,
|
59 |
+
saved_cross_attn_keys,
|
60 |
+
sam_refine_kwargs,
|
61 |
+
num_inference_steps,
|
62 |
+
gligen_scheduled_sampling_beta=0.3,
|
63 |
+
verbose=False,
|
64 |
+
visualize=False,
|
65 |
+
**kwargs,
|
66 |
+
):
|
67 |
+
bboxes, phrases, words = [box], [phrase], [word]
|
68 |
+
|
69 |
+
if verbose:
|
70 |
+
print(f"Getting token map (prompt: {prompt})")
|
71 |
+
|
72 |
+
object_positions, word_token_indices = guidance.get_phrase_indices(
|
73 |
+
tokenizer=tokenizer,
|
74 |
+
prompt=prompt,
|
75 |
+
phrases=phrases,
|
76 |
+
words=words,
|
77 |
+
return_word_token_indices=True,
|
78 |
+
# Since the prompt for single object is from background prompt + object name, we will not have the case of not found
|
79 |
+
add_suffix_if_not_found=False,
|
80 |
+
verbose=verbose,
|
81 |
+
)
|
82 |
+
# phrases only has one item, so we select the first item in word_token_indices
|
83 |
+
word_token_index = word_token_indices[0]
|
84 |
+
|
85 |
+
if verbose:
|
86 |
+
print("word_token_index:", word_token_index)
|
87 |
+
|
88 |
+
# `offload_guidance_cross_attn_to_cpu` will greatly slow down generation
|
89 |
+
(
|
90 |
+
latents,
|
91 |
+
single_object_images,
|
92 |
+
saved_attns,
|
93 |
+
single_object_pil_images_box_ann,
|
94 |
+
latents_all,
|
95 |
+
) = pipelines.generate_gligen(
|
96 |
+
model_dict,
|
97 |
+
input_latents,
|
98 |
+
input_embeddings,
|
99 |
+
num_inference_steps,
|
100 |
+
bboxes,
|
101 |
+
phrases,
|
102 |
+
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
103 |
+
guidance_scale=guidance_scale,
|
104 |
+
return_saved_cross_attn=True,
|
105 |
+
semantic_guidance=True,
|
106 |
+
semantic_guidance_bboxes=bboxes,
|
107 |
+
semantic_guidance_object_positions=object_positions,
|
108 |
+
semantic_guidance_kwargs=semantic_guidance_kwargs,
|
109 |
+
saved_cross_attn_keys=[obj_attn_key, *saved_cross_attn_keys],
|
110 |
+
return_cond_ca_only=True,
|
111 |
+
return_token_ca_only=word_token_index,
|
112 |
+
offload_cross_attn_to_cpu=False,
|
113 |
+
return_box_vis=True,
|
114 |
+
save_all_latents=True,
|
115 |
+
dynamic_num_inference_steps=True,
|
116 |
+
**kwargs,
|
117 |
+
)
|
118 |
+
# `saved_cross_attn_keys` kwargs may have duplicates
|
119 |
+
|
120 |
+
utils.free_memory()
|
121 |
+
|
122 |
+
single_object_pil_image_box_ann = single_object_pil_images_box_ann[0]
|
123 |
+
|
124 |
+
if visualize:
|
125 |
+
print("Single object image")
|
126 |
+
vis.display(single_object_pil_image_box_ann)
|
127 |
+
|
128 |
+
mask_selected, conf_score_selected = sam.sam_refine_box(
|
129 |
+
sam_input_image=single_object_images[0],
|
130 |
+
box=box,
|
131 |
+
model_dict=model_dict,
|
132 |
+
verbose=verbose,
|
133 |
+
**sam_refine_kwargs,
|
134 |
+
)
|
135 |
+
|
136 |
mask_selected_tensor = torch.tensor(mask_selected)
|
137 |
+
|
138 |
+
if verbose:
|
139 |
+
vis.visualize(mask_selected, "Mask (selected) after resize")
|
140 |
+
# This is only for visualizations
|
141 |
+
masked_latents = latents_all * mask_selected_tensor[None, None, None, ...]
|
142 |
+
vis.visualize_masked_latents(
|
143 |
+
latents_all, masked_latents, timestep_T=False, timestep_0=True
|
144 |
+
)
|
145 |
+
|
146 |
+
return (
|
147 |
+
latents_all,
|
148 |
+
mask_selected_tensor,
|
149 |
+
saved_attns,
|
150 |
+
single_object_pil_image_box_ann,
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def get_masked_latents_all_list(
|
155 |
+
so_prompt_phrase_word_box_list,
|
156 |
+
input_latents_list,
|
157 |
+
so_input_embeddings,
|
158 |
+
verbose=False,
|
159 |
+
**kwargs,
|
160 |
+
):
|
161 |
+
latents_all_list, mask_tensor_list, saved_attns_list, so_img_list = [], [], [], []
|
162 |
+
|
163 |
if not so_prompt_phrase_word_box_list:
|
164 |
+
return latents_all_list, mask_tensor_list, saved_attns_list
|
|
|
|
|
165 |
|
166 |
+
so_uncond_embeddings, so_cond_embeddings = so_input_embeddings
|
167 |
+
|
168 |
+
for idx, ((prompt, phrase, word, box), input_latents) in enumerate(
|
169 |
+
zip(so_prompt_phrase_word_box_list, input_latents_list)
|
170 |
+
):
|
171 |
+
so_current_cond_embeddings = so_cond_embeddings[idx : idx + 1]
|
172 |
+
so_current_text_embeddings = torch.cat(
|
173 |
+
[so_uncond_embeddings, so_current_cond_embeddings], dim=0
|
174 |
+
)
|
175 |
+
so_current_input_embeddings = (
|
176 |
+
so_current_text_embeddings,
|
177 |
+
so_uncond_embeddings,
|
178 |
+
so_current_cond_embeddings,
|
179 |
+
)
|
180 |
+
|
181 |
+
latents_all, mask_tensor, saved_attns, so_img = generate_single_object_with_box(
|
182 |
+
prompt,
|
183 |
+
box,
|
184 |
+
phrase,
|
185 |
+
word,
|
186 |
+
input_latents,
|
187 |
+
input_embeddings=so_current_input_embeddings,
|
188 |
+
verbose=verbose,
|
189 |
+
**kwargs,
|
190 |
+
)
|
191 |
+
latents_all_list.append(latents_all)
|
192 |
+
mask_tensor_list.append(mask_tensor)
|
193 |
+
saved_attns_list.append(saved_attns)
|
194 |
+
so_img_list.append(so_img)
|
195 |
|
196 |
+
return latents_all_list, mask_tensor_list, saved_attns_list, so_img_list
|
197 |
|
198 |
|
199 |
# Note: need to keep the supervision, especially the box corrdinates, corresponds to each other in single object and overall.
|
200 |
|
201 |
+
|
202 |
def run(
|
203 |
+
spec,
|
204 |
+
bg_seed=1,
|
205 |
+
overall_prompt_override="",
|
206 |
+
fg_seed_start=20,
|
207 |
+
frozen_step_ratio=0.4,
|
208 |
+
num_inference_steps=20,
|
209 |
+
loss_scale=20,
|
210 |
+
loss_threshold=5.0,
|
211 |
+
max_iter=[2] * 5 + [1] * 10,
|
212 |
+
max_index_step=15,
|
213 |
+
overall_loss_scale=20,
|
214 |
+
overall_loss_threshold=5.0,
|
215 |
+
overall_max_iter=[4] * 5 + [3] * 5 + [2] * 5 + [2] * 5 + [1] * 10,
|
216 |
+
overall_max_index_step=30,
|
217 |
+
so_gligen_scheduled_sampling_beta=0.4,
|
218 |
+
overall_gligen_scheduled_sampling_beta=0.4,
|
219 |
+
ref_ca_loss_weight=0.5,
|
220 |
+
so_center_box=False,
|
221 |
+
fg_blending_ratio=0.1,
|
222 |
+
scheduler_key="dpm_scheduler",
|
223 |
+
so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT,
|
224 |
+
overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT,
|
225 |
+
so_horizontal_center_only=True,
|
226 |
+
align_with_overall_bboxes=False,
|
227 |
+
horizontal_shift_only=True,
|
228 |
+
use_fast_schedule=True,
|
229 |
+
# Transfer the cross-attention from single object generation (with ref_ca_saved_attns)
|
230 |
+
# Use reference cross attention to guide the cross attention in the overall generation
|
231 |
+
use_ref_ca=True,
|
232 |
+
use_autocast=False,
|
233 |
):
|
234 |
+
"""
|
235 |
so_center_box: using centered box in single object generation
|
236 |
so_horizontal_center_only: move to the center horizontally only
|
237 |
+
|
238 |
align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
|
239 |
horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
|
240 |
"""
|
241 |
+
|
242 |
+
frozen_step_ratio = min(max(frozen_step_ratio, 0.0), 1.0)
|
|
|
|
|
243 |
frozen_steps = int(num_inference_steps * frozen_step_ratio)
|
244 |
|
245 |
+
print(
|
246 |
+
"generation:",
|
247 |
+
spec,
|
248 |
+
bg_seed,
|
249 |
+
fg_seed_start,
|
250 |
+
frozen_step_ratio,
|
251 |
+
so_gligen_scheduled_sampling_beta,
|
252 |
+
overall_gligen_scheduled_sampling_beta,
|
253 |
+
overall_max_index_step,
|
254 |
+
)
|
255 |
+
|
256 |
+
(
|
257 |
+
so_prompt_phrase_word_box_list,
|
258 |
+
overall_prompt,
|
259 |
+
overall_phrases_words_bboxes,
|
260 |
+
) = parse.convert_spec(spec, height, width, verbose=verbose)
|
261 |
|
262 |
if overall_prompt_override and overall_prompt_override.strip():
|
263 |
overall_prompt = overall_prompt_override.strip()
|
264 |
|
265 |
+
overall_phrases, overall_words, overall_bboxes = (
|
266 |
+
[item[0] for item in overall_phrases_words_bboxes],
|
267 |
+
[item[1] for item in overall_phrases_words_bboxes],
|
268 |
+
[item[2] for item in overall_phrases_words_bboxes],
|
269 |
+
)
|
270 |
|
271 |
# The so box is centered but the overall boxes are not (since we need to place to the right place).
|
272 |
if so_center_box:
|
273 |
+
so_prompt_phrase_word_box_list = [
|
274 |
+
(
|
275 |
+
prompt,
|
276 |
+
phrase,
|
277 |
+
word,
|
278 |
+
utils.get_centered_box(
|
279 |
+
bbox, horizontal_center_only=so_horizontal_center_only
|
280 |
+
),
|
281 |
+
)
|
282 |
+
for prompt, phrase, word, bbox in so_prompt_phrase_word_box_list
|
283 |
+
]
|
284 |
if verbose:
|
285 |
+
print(
|
286 |
+
f"centered so_prompt_phrase_word_box_list: {so_prompt_phrase_word_box_list}"
|
287 |
+
)
|
288 |
so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
|
289 |
|
290 |
+
so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT
|
291 |
+
overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT
|
292 |
+
if "extra_neg_prompt" in spec and spec["extra_neg_prompt"]:
|
293 |
+
so_negative_prompt = spec["extra_neg_prompt"] + ", " + so_negative_prompt
|
294 |
+
overall_negative_prompt = (
|
295 |
+
spec["extra_neg_prompt"] + ", " + overall_negative_prompt
|
296 |
+
)
|
297 |
+
|
298 |
+
semantic_guidance_kwargs = dict(
|
299 |
+
loss_scale=loss_scale,
|
300 |
+
loss_threshold=loss_threshold,
|
301 |
+
max_iter=max_iter,
|
302 |
+
max_index_step=max_index_step,
|
303 |
+
use_ratio_based_loss=False,
|
304 |
+
guidance_attn_keys=guidance_attn_keys,
|
305 |
+
verbose=True,
|
306 |
+
)
|
307 |
+
|
308 |
sam_refine_kwargs = dict(
|
309 |
+
discourage_mask_below_confidence=discourage_mask_below_confidence,
|
310 |
+
discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
|
311 |
+
height=height,
|
312 |
+
width=width,
|
313 |
+
H=H,
|
314 |
+
W=W,
|
315 |
)
|
316 |
+
|
317 |
+
if verbose:
|
318 |
+
vis.visualize_bboxes(
|
319 |
+
bboxes=[item[-1] for item in so_prompt_phrase_word_box_list], H=H, W=W
|
320 |
+
)
|
321 |
+
|
322 |
# Note that so and overall use different negative prompts
|
323 |
|
324 |
with torch.autocast("cuda", enabled=use_autocast):
|
325 |
so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
|
326 |
if so_prompts:
|
327 |
+
so_input_embeddings = models.encode_prompts(
|
328 |
+
prompts=so_prompts,
|
329 |
+
tokenizer=tokenizer,
|
330 |
+
text_encoder=text_encoder,
|
331 |
+
negative_prompt=so_negative_prompt,
|
332 |
+
one_uncond_input_only=True,
|
333 |
+
)
|
334 |
else:
|
335 |
so_input_embeddings = []
|
336 |
|
|
|
|
|
337 |
input_latents_list, latents_bg = latents.get_input_latents_list(
|
338 |
+
model_dict,
|
339 |
+
bg_seed=bg_seed,
|
340 |
+
fg_seed_start=fg_seed_start,
|
341 |
+
so_boxes=so_boxes,
|
342 |
+
fg_blending_ratio=fg_blending_ratio,
|
343 |
+
height=height,
|
344 |
+
width=width,
|
345 |
+
verbose=False,
|
346 |
+
)
|
347 |
+
|
348 |
+
if use_fast_schedule:
|
349 |
+
fast_after_steps = max(frozen_steps, overall_max_index_step) if use_ref_ca else frozen_steps
|
350 |
+
else:
|
351 |
+
fast_after_steps = None
|
352 |
+
|
353 |
+
if use_ref_ca or frozen_steps > 0:
|
354 |
+
(
|
355 |
+
latents_all_list,
|
356 |
+
mask_tensor_list,
|
357 |
+
saved_attns_list,
|
358 |
+
so_img_list,
|
359 |
+
) = get_masked_latents_all_list(
|
360 |
+
so_prompt_phrase_word_box_list,
|
361 |
+
input_latents_list,
|
362 |
+
gligen_scheduled_sampling_beta=so_gligen_scheduled_sampling_beta,
|
363 |
+
semantic_guidance_kwargs=semantic_guidance_kwargs,
|
364 |
+
obj_attn_key=("down", 2, 1, 0),
|
365 |
+
saved_cross_attn_keys=guidance_attn_keys if use_ref_ca else [],
|
366 |
+
sam_refine_kwargs=sam_refine_kwargs,
|
367 |
+
so_input_embeddings=so_input_embeddings,
|
368 |
+
num_inference_steps=num_inference_steps,
|
369 |
+
scheduler_key=scheduler_key,
|
370 |
+
verbose=verbose,
|
371 |
+
fast_after_steps=fast_after_steps,
|
372 |
+
fast_rate=2,
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
# No per-box guidance
|
376 |
+
(latents_all_list, mask_tensor_list, saved_attns_list, so_img_list) = [], [], [], []
|
377 |
+
|
378 |
+
(
|
379 |
+
composed_latents,
|
380 |
+
foreground_indices,
|
381 |
+
offset_list,
|
382 |
+
) = latents.compose_latents_with_alignment(
|
383 |
+
model_dict,
|
384 |
+
latents_all_list,
|
385 |
+
mask_tensor_list,
|
386 |
+
num_inference_steps,
|
387 |
+
overall_batch_size,
|
388 |
+
height,
|
389 |
+
width,
|
390 |
+
latents_bg=latents_bg,
|
391 |
+
align_with_overall_bboxes=align_with_overall_bboxes,
|
392 |
+
overall_bboxes=overall_bboxes,
|
393 |
+
horizontal_shift_only=horizontal_shift_only,
|
394 |
+
use_fast_schedule=use_fast_schedule,
|
395 |
+
fast_after_steps=fast_after_steps,
|
396 |
)
|
397 |
+
|
398 |
+
# NOTE: need to ensure overall embeddings are generated after the update of overall prompt
|
399 |
+
(
|
400 |
+
overall_object_positions,
|
401 |
+
overall_word_token_indices,
|
402 |
+
overall_prompt
|
403 |
+
) = guidance.get_phrase_indices(
|
404 |
+
tokenizer=tokenizer,
|
405 |
+
prompt=overall_prompt,
|
406 |
+
phrases=overall_phrases,
|
407 |
+
words=overall_words,
|
408 |
+
verbose=verbose,
|
409 |
+
return_word_token_indices=True,
|
410 |
+
add_suffix_if_not_found=True
|
411 |
)
|
412 |
|
413 |
+
overall_input_embeddings = models.encode_prompts(
|
414 |
+
prompts=[overall_prompt],
|
415 |
+
tokenizer=tokenizer,
|
416 |
+
negative_prompt=overall_negative_prompt,
|
417 |
+
text_encoder=text_encoder,
|
418 |
)
|
419 |
+
|
420 |
+
if use_ref_ca:
|
421 |
+
# ref_ca_saved_attns has the same hierarchy as bboxes
|
422 |
+
ref_ca_saved_attns = []
|
423 |
+
|
424 |
+
flattened_box_idx = 0
|
425 |
+
for bboxes in overall_bboxes:
|
426 |
+
# bboxes: correspond to a phrase
|
427 |
+
ref_ca_current_phrase_saved_attns = []
|
428 |
+
for bbox in bboxes:
|
429 |
+
# each individual bbox
|
430 |
+
saved_attns = saved_attns_list[flattened_box_idx]
|
431 |
+
if align_with_overall_bboxes:
|
432 |
+
offset = offset_list[flattened_box_idx]
|
433 |
+
saved_attns = attn.shift_saved_attns(
|
434 |
+
saved_attns,
|
435 |
+
offset,
|
436 |
+
guidance_attn_keys=guidance_attn_keys,
|
437 |
+
horizontal_shift_only=horizontal_shift_only,
|
438 |
+
)
|
439 |
+
ref_ca_current_phrase_saved_attns.append(saved_attns)
|
440 |
+
flattened_box_idx += 1
|
441 |
+
ref_ca_saved_attns.append(ref_ca_current_phrase_saved_attns)
|
442 |
+
|
443 |
overall_bboxes_flattened, overall_phrases_flattened = [], []
|
444 |
for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
|
445 |
for overall_bbox in overall_bboxes_item:
|
446 |
overall_bboxes_flattened.append(overall_bbox)
|
447 |
overall_phrases_flattened.append(overall_phrase)
|
448 |
|
449 |
+
# This is currently not-shared with the single object one.
|
450 |
+
overall_semantic_guidance_kwargs = dict(
|
451 |
+
loss_scale=overall_loss_scale,
|
452 |
+
loss_threshold=overall_loss_threshold,
|
453 |
+
max_iter=overall_max_iter,
|
454 |
+
max_index_step=overall_max_index_step,
|
455 |
+
# ref_ca comes from the attention map of the word token of the phrase in single object generation, so we apply it only to the word token of the phrase in overall generation.
|
456 |
+
ref_ca_word_token_only=True,
|
457 |
+
# If a word is not provided, we use the last token.
|
458 |
+
ref_ca_last_token_only=True,
|
459 |
+
ref_ca_saved_attns=ref_ca_saved_attns if use_ref_ca else None,
|
460 |
+
word_token_indices=overall_word_token_indices,
|
461 |
+
guidance_attn_keys=guidance_attn_keys,
|
462 |
+
ref_ca_loss_weight=ref_ca_loss_weight,
|
463 |
+
use_ratio_based_loss=False,
|
464 |
+
verbose=True,
|
465 |
+
)
|
466 |
+
|
467 |
# Generate with composed latents
|
468 |
|
469 |
# Foreground should be frozen
|
470 |
frozen_mask = foreground_indices != 0
|
471 |
+
|
472 |
+
_, images = pipelines.generate_gligen(
|
473 |
+
model_dict,
|
474 |
+
composed_latents,
|
475 |
+
overall_input_embeddings,
|
476 |
+
num_inference_steps,
|
477 |
+
overall_bboxes_flattened,
|
478 |
+
overall_phrases_flattened,
|
479 |
+
guidance_scale=guidance_scale,
|
480 |
+
gligen_scheduled_sampling_beta=overall_gligen_scheduled_sampling_beta,
|
481 |
+
semantic_guidance=True,
|
482 |
+
semantic_guidance_bboxes=overall_bboxes,
|
483 |
+
semantic_guidance_object_positions=overall_object_positions,
|
484 |
+
semantic_guidance_kwargs=overall_semantic_guidance_kwargs,
|
485 |
+
frozen_steps=frozen_steps,
|
486 |
+
frozen_mask=frozen_mask,
|
487 |
+
scheduler_key=scheduler_key,
|
488 |
)
|
489 |
|
490 |
+
print(
|
491 |
+
f"Generation with spatial guidance from input latents and first {frozen_steps} steps frozen (directly from the composed latents input)"
|
492 |
+
)
|
493 |
print("Generation from composed latents (with semantic guidance)")
|
494 |
|
495 |
+
utils.free_memory()
|
|
|
|
|
|
|
|
|
|
|
496 |
|
497 |
+
return images[0], so_img_list
|
models/modeling_utils.py
DELETED
@@ -1,874 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
|
17 |
-
import inspect
|
18 |
-
import itertools
|
19 |
-
import os
|
20 |
-
from functools import partial
|
21 |
-
from typing import Any, Callable, List, Optional, Tuple, Union
|
22 |
-
|
23 |
-
import torch
|
24 |
-
from torch import Tensor, device
|
25 |
-
|
26 |
-
from diffusers import __version__
|
27 |
-
from diffusers.utils import (
|
28 |
-
CONFIG_NAME,
|
29 |
-
DIFFUSERS_CACHE,
|
30 |
-
FLAX_WEIGHTS_NAME,
|
31 |
-
HF_HUB_OFFLINE,
|
32 |
-
SAFETENSORS_WEIGHTS_NAME,
|
33 |
-
WEIGHTS_NAME,
|
34 |
-
_add_variant,
|
35 |
-
_get_model_file,
|
36 |
-
deprecate,
|
37 |
-
is_accelerate_available,
|
38 |
-
is_safetensors_available,
|
39 |
-
is_torch_version,
|
40 |
-
logging,
|
41 |
-
)
|
42 |
-
|
43 |
-
|
44 |
-
logger = logging.get_logger(__name__)
|
45 |
-
|
46 |
-
|
47 |
-
if is_torch_version(">=", "1.9.0"):
|
48 |
-
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
49 |
-
else:
|
50 |
-
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
51 |
-
|
52 |
-
|
53 |
-
if is_accelerate_available():
|
54 |
-
import accelerate
|
55 |
-
from accelerate.utils import set_module_tensor_to_device
|
56 |
-
from accelerate.utils.versions import is_torch_version
|
57 |
-
|
58 |
-
if is_safetensors_available():
|
59 |
-
import safetensors
|
60 |
-
|
61 |
-
|
62 |
-
def get_parameter_device(parameter: torch.nn.Module):
|
63 |
-
try:
|
64 |
-
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
65 |
-
return next(parameters_and_buffers).device
|
66 |
-
except StopIteration:
|
67 |
-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
68 |
-
|
69 |
-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
70 |
-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
71 |
-
return tuples
|
72 |
-
|
73 |
-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
74 |
-
first_tuple = next(gen)
|
75 |
-
return first_tuple[1].device
|
76 |
-
|
77 |
-
|
78 |
-
def get_parameter_dtype(parameter: torch.nn.Module):
|
79 |
-
try:
|
80 |
-
params = tuple(parameter.parameters())
|
81 |
-
if len(params) > 0:
|
82 |
-
return params[0].dtype
|
83 |
-
|
84 |
-
buffers = tuple(parameter.buffers())
|
85 |
-
if len(buffers) > 0:
|
86 |
-
return buffers[0].dtype
|
87 |
-
|
88 |
-
except StopIteration:
|
89 |
-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
90 |
-
|
91 |
-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
92 |
-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
93 |
-
return tuples
|
94 |
-
|
95 |
-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
96 |
-
first_tuple = next(gen)
|
97 |
-
return first_tuple[1].dtype
|
98 |
-
|
99 |
-
|
100 |
-
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
101 |
-
"""
|
102 |
-
Reads a checkpoint file, returning properly formatted errors if they arise.
|
103 |
-
"""
|
104 |
-
try:
|
105 |
-
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
106 |
-
return torch.load(checkpoint_file, map_location="cpu")
|
107 |
-
else:
|
108 |
-
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
109 |
-
except Exception as e:
|
110 |
-
try:
|
111 |
-
with open(checkpoint_file) as f:
|
112 |
-
if f.read().startswith("version"):
|
113 |
-
raise OSError(
|
114 |
-
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
115 |
-
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
116 |
-
"you cloned."
|
117 |
-
)
|
118 |
-
else:
|
119 |
-
raise ValueError(
|
120 |
-
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
121 |
-
"model. Make sure you have saved the model properly."
|
122 |
-
) from e
|
123 |
-
except (UnicodeDecodeError, ValueError):
|
124 |
-
raise OSError(
|
125 |
-
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
126 |
-
f"at '{checkpoint_file}'. "
|
127 |
-
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
128 |
-
)
|
129 |
-
|
130 |
-
|
131 |
-
def _load_state_dict_into_model(model_to_load, state_dict):
|
132 |
-
# Convert old format to new format if needed from a PyTorch state_dict
|
133 |
-
# copy state_dict so _load_from_state_dict can modify it
|
134 |
-
state_dict = state_dict.copy()
|
135 |
-
error_msgs = []
|
136 |
-
|
137 |
-
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
138 |
-
# so we need to apply the function recursively.
|
139 |
-
def load(module: torch.nn.Module, prefix=""):
|
140 |
-
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
141 |
-
module._load_from_state_dict(*args)
|
142 |
-
|
143 |
-
for name, child in module._modules.items():
|
144 |
-
if child is not None:
|
145 |
-
load(child, prefix + name + ".")
|
146 |
-
|
147 |
-
load(model_to_load)
|
148 |
-
|
149 |
-
return error_msgs
|
150 |
-
|
151 |
-
|
152 |
-
class ModelMixin(torch.nn.Module):
|
153 |
-
r"""
|
154 |
-
Base class for all models.
|
155 |
-
|
156 |
-
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
157 |
-
and saving models.
|
158 |
-
|
159 |
-
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
160 |
-
[`~models.ModelMixin.save_pretrained`].
|
161 |
-
"""
|
162 |
-
config_name = CONFIG_NAME
|
163 |
-
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
164 |
-
_supports_gradient_checkpointing = False
|
165 |
-
|
166 |
-
def __init__(self):
|
167 |
-
super().__init__()
|
168 |
-
|
169 |
-
def __getattr__(self, name: str) -> Any:
|
170 |
-
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
171 |
-
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
172 |
-
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
173 |
-
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
174 |
-
"""
|
175 |
-
|
176 |
-
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
177 |
-
is_attribute = name in self.__dict__
|
178 |
-
|
179 |
-
if is_in_config and not is_attribute:
|
180 |
-
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
181 |
-
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
182 |
-
return self._internal_dict[name]
|
183 |
-
|
184 |
-
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
185 |
-
return super().__getattr__(name)
|
186 |
-
|
187 |
-
@property
|
188 |
-
def is_gradient_checkpointing(self) -> bool:
|
189 |
-
"""
|
190 |
-
Whether gradient checkpointing is activated for this model or not.
|
191 |
-
|
192 |
-
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
193 |
-
activations".
|
194 |
-
"""
|
195 |
-
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
196 |
-
|
197 |
-
def enable_gradient_checkpointing(self):
|
198 |
-
"""
|
199 |
-
Activates gradient checkpointing for the current model.
|
200 |
-
|
201 |
-
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
202 |
-
activations".
|
203 |
-
"""
|
204 |
-
if not self._supports_gradient_checkpointing:
|
205 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
206 |
-
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
207 |
-
|
208 |
-
def disable_gradient_checkpointing(self):
|
209 |
-
"""
|
210 |
-
Deactivates gradient checkpointing for the current model.
|
211 |
-
|
212 |
-
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
213 |
-
activations".
|
214 |
-
"""
|
215 |
-
if self._supports_gradient_checkpointing:
|
216 |
-
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
217 |
-
|
218 |
-
def set_use_memory_efficient_attention_xformers(
|
219 |
-
self, valid: bool, attention_op: Optional[Callable] = None
|
220 |
-
) -> None:
|
221 |
-
# Recursively walk through all the children.
|
222 |
-
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
223 |
-
# gets the message
|
224 |
-
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
225 |
-
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
226 |
-
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
227 |
-
|
228 |
-
for child in module.children():
|
229 |
-
fn_recursive_set_mem_eff(child)
|
230 |
-
|
231 |
-
for module in self.children():
|
232 |
-
if isinstance(module, torch.nn.Module):
|
233 |
-
fn_recursive_set_mem_eff(module)
|
234 |
-
|
235 |
-
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
236 |
-
r"""
|
237 |
-
Enable memory efficient attention as implemented in xformers.
|
238 |
-
|
239 |
-
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
240 |
-
time. Speed up at training time is not guaranteed.
|
241 |
-
|
242 |
-
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
243 |
-
is used.
|
244 |
-
|
245 |
-
Parameters:
|
246 |
-
attention_op (`Callable`, *optional*):
|
247 |
-
Override the default `None` operator for use as `op` argument to the
|
248 |
-
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
249 |
-
function of xFormers.
|
250 |
-
|
251 |
-
Examples:
|
252 |
-
|
253 |
-
```py
|
254 |
-
>>> import torch
|
255 |
-
>>> from diffusers import UNet2DConditionModel
|
256 |
-
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
257 |
-
|
258 |
-
>>> model = UNet2DConditionModel.from_pretrained(
|
259 |
-
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
260 |
-
... )
|
261 |
-
>>> model = model.to("cuda")
|
262 |
-
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
263 |
-
```
|
264 |
-
"""
|
265 |
-
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
266 |
-
|
267 |
-
def disable_xformers_memory_efficient_attention(self):
|
268 |
-
r"""
|
269 |
-
Disable memory efficient attention as implemented in xformers.
|
270 |
-
"""
|
271 |
-
self.set_use_memory_efficient_attention_xformers(False)
|
272 |
-
|
273 |
-
def save_pretrained(
|
274 |
-
self,
|
275 |
-
save_directory: Union[str, os.PathLike],
|
276 |
-
is_main_process: bool = True,
|
277 |
-
save_function: Callable = None,
|
278 |
-
safe_serialization: bool = False,
|
279 |
-
variant: Optional[str] = None,
|
280 |
-
):
|
281 |
-
"""
|
282 |
-
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
283 |
-
`[`~models.ModelMixin.from_pretrained`]` class method.
|
284 |
-
|
285 |
-
Arguments:
|
286 |
-
save_directory (`str` or `os.PathLike`):
|
287 |
-
Directory to which to save. Will be created if it doesn't exist.
|
288 |
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
289 |
-
Whether the process calling this is the main process or not. Useful when in distributed training like
|
290 |
-
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
291 |
-
the main process to avoid race conditions.
|
292 |
-
save_function (`Callable`):
|
293 |
-
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
294 |
-
need to replace `torch.save` by another method. Can be configured with the environment variable
|
295 |
-
`DIFFUSERS_SAVE_MODE`.
|
296 |
-
safe_serialization (`bool`, *optional*, defaults to `False`):
|
297 |
-
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
298 |
-
variant (`str`, *optional*):
|
299 |
-
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
300 |
-
"""
|
301 |
-
if safe_serialization and not is_safetensors_available():
|
302 |
-
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
303 |
-
|
304 |
-
if os.path.isfile(save_directory):
|
305 |
-
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
306 |
-
return
|
307 |
-
|
308 |
-
os.makedirs(save_directory, exist_ok=True)
|
309 |
-
|
310 |
-
model_to_save = self
|
311 |
-
|
312 |
-
# Attach architecture to the config
|
313 |
-
# Save the config
|
314 |
-
if is_main_process:
|
315 |
-
model_to_save.save_config(save_directory)
|
316 |
-
|
317 |
-
# Save the model
|
318 |
-
state_dict = model_to_save.state_dict()
|
319 |
-
|
320 |
-
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
321 |
-
weights_name = _add_variant(weights_name, variant)
|
322 |
-
|
323 |
-
# Save the model
|
324 |
-
if safe_serialization:
|
325 |
-
safetensors.torch.save_file(
|
326 |
-
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
327 |
-
)
|
328 |
-
else:
|
329 |
-
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
330 |
-
|
331 |
-
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
332 |
-
|
333 |
-
@classmethod
|
334 |
-
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
335 |
-
r"""
|
336 |
-
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
337 |
-
|
338 |
-
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
339 |
-
the model, you should first set it back in training mode with `model.train()`.
|
340 |
-
|
341 |
-
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
342 |
-
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
343 |
-
task.
|
344 |
-
|
345 |
-
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
346 |
-
weights are discarded.
|
347 |
-
|
348 |
-
Parameters:
|
349 |
-
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
350 |
-
Can be either:
|
351 |
-
|
352 |
-
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
353 |
-
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
354 |
-
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
355 |
-
`./my_model_directory/`.
|
356 |
-
|
357 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
358 |
-
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
359 |
-
standard cache should not be used.
|
360 |
-
torch_dtype (`str` or `torch.dtype`, *optional*):
|
361 |
-
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
362 |
-
will be automatically derived from the model's weights.
|
363 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
364 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
365 |
-
cached versions if they exist.
|
366 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
367 |
-
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
368 |
-
file exists.
|
369 |
-
proxies (`Dict[str, str]`, *optional*):
|
370 |
-
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
371 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
372 |
-
output_loading_info(`bool`, *optional*, defaults to `False`):
|
373 |
-
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
374 |
-
local_files_only(`bool`, *optional*, defaults to `False`):
|
375 |
-
Whether or not to only look at local files (i.e., do not try to download the model).
|
376 |
-
use_auth_token (`str` or *bool*, *optional*):
|
377 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
378 |
-
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
379 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
380 |
-
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
381 |
-
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
382 |
-
identifier allowed by git.
|
383 |
-
from_flax (`bool`, *optional*, defaults to `False`):
|
384 |
-
Load the model weights from a Flax checkpoint save file.
|
385 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
386 |
-
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
387 |
-
huggingface.co or downloaded locally), you can specify the folder name here.
|
388 |
-
|
389 |
-
mirror (`str`, *optional*):
|
390 |
-
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
391 |
-
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
392 |
-
Please refer to the mirror site for more information.
|
393 |
-
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
394 |
-
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
395 |
-
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
396 |
-
same device.
|
397 |
-
|
398 |
-
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
399 |
-
more information about each option see [designing a device
|
400 |
-
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
401 |
-
max_memory (`Dict`, *optional*):
|
402 |
-
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
403 |
-
GPU and the available CPU RAM if unset.
|
404 |
-
offload_folder (`str` or `os.PathLike`, *optional*):
|
405 |
-
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
406 |
-
offload_state_dict (`bool`, *optional*):
|
407 |
-
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
408 |
-
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
|
409 |
-
`True` when there is some disk offload.
|
410 |
-
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
411 |
-
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
412 |
-
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
413 |
-
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
414 |
-
setting this argument to `True` will raise an error.
|
415 |
-
variant (`str`, *optional*):
|
416 |
-
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
417 |
-
ignored when using `from_flax`.
|
418 |
-
use_safetensors (`bool`, *optional*, defaults to `None`):
|
419 |
-
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
420 |
-
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
421 |
-
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
422 |
-
|
423 |
-
<Tip>
|
424 |
-
|
425 |
-
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
426 |
-
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
427 |
-
|
428 |
-
</Tip>
|
429 |
-
|
430 |
-
<Tip>
|
431 |
-
|
432 |
-
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
433 |
-
this method in a firewalled environment.
|
434 |
-
|
435 |
-
</Tip>
|
436 |
-
|
437 |
-
"""
|
438 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
439 |
-
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
440 |
-
force_download = kwargs.pop("force_download", False)
|
441 |
-
from_flax = kwargs.pop("from_flax", False)
|
442 |
-
resume_download = kwargs.pop("resume_download", False)
|
443 |
-
proxies = kwargs.pop("proxies", None)
|
444 |
-
output_loading_info = kwargs.pop("output_loading_info", False)
|
445 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
446 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
447 |
-
revision = kwargs.pop("revision", None)
|
448 |
-
torch_dtype = kwargs.pop("torch_dtype", None)
|
449 |
-
subfolder = kwargs.pop("subfolder", None)
|
450 |
-
device_map = kwargs.pop("device_map", None)
|
451 |
-
max_memory = kwargs.pop("max_memory", None)
|
452 |
-
offload_folder = kwargs.pop("offload_folder", None)
|
453 |
-
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
454 |
-
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
455 |
-
variant = kwargs.pop("variant", None)
|
456 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
457 |
-
|
458 |
-
if use_safetensors and not is_safetensors_available():
|
459 |
-
raise ValueError(
|
460 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
461 |
-
)
|
462 |
-
|
463 |
-
allow_pickle = False
|
464 |
-
if use_safetensors is None:
|
465 |
-
use_safetensors = is_safetensors_available()
|
466 |
-
allow_pickle = True
|
467 |
-
|
468 |
-
if low_cpu_mem_usage and not is_accelerate_available():
|
469 |
-
low_cpu_mem_usage = False
|
470 |
-
logger.warning(
|
471 |
-
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
472 |
-
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
473 |
-
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
474 |
-
" install accelerate\n```\n."
|
475 |
-
)
|
476 |
-
|
477 |
-
if device_map is not None and not is_accelerate_available():
|
478 |
-
raise NotImplementedError(
|
479 |
-
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
480 |
-
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
481 |
-
)
|
482 |
-
|
483 |
-
# Check if we can handle device_map and dispatching the weights
|
484 |
-
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
485 |
-
raise NotImplementedError(
|
486 |
-
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
487 |
-
" `device_map=None`."
|
488 |
-
)
|
489 |
-
|
490 |
-
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
491 |
-
raise NotImplementedError(
|
492 |
-
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
493 |
-
" `low_cpu_mem_usage=False`."
|
494 |
-
)
|
495 |
-
|
496 |
-
if low_cpu_mem_usage is False and device_map is not None:
|
497 |
-
raise ValueError(
|
498 |
-
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
499 |
-
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
500 |
-
)
|
501 |
-
|
502 |
-
# Load config if we don't provide a configuration
|
503 |
-
config_path = pretrained_model_name_or_path
|
504 |
-
|
505 |
-
user_agent = {
|
506 |
-
"diffusers": __version__,
|
507 |
-
"file_type": "model",
|
508 |
-
"framework": "pytorch",
|
509 |
-
}
|
510 |
-
|
511 |
-
# load config
|
512 |
-
config, unused_kwargs, commit_hash = cls.load_config(
|
513 |
-
config_path,
|
514 |
-
cache_dir=cache_dir,
|
515 |
-
return_unused_kwargs=True,
|
516 |
-
return_commit_hash=True,
|
517 |
-
force_download=force_download,
|
518 |
-
resume_download=resume_download,
|
519 |
-
proxies=proxies,
|
520 |
-
local_files_only=local_files_only,
|
521 |
-
use_auth_token=use_auth_token,
|
522 |
-
revision=revision,
|
523 |
-
subfolder=subfolder,
|
524 |
-
device_map=device_map,
|
525 |
-
max_memory=max_memory,
|
526 |
-
offload_folder=offload_folder,
|
527 |
-
offload_state_dict=offload_state_dict,
|
528 |
-
user_agent=user_agent,
|
529 |
-
**kwargs,
|
530 |
-
)
|
531 |
-
|
532 |
-
# load model
|
533 |
-
model_file = None
|
534 |
-
if from_flax:
|
535 |
-
model_file = _get_model_file(
|
536 |
-
pretrained_model_name_or_path,
|
537 |
-
weights_name=FLAX_WEIGHTS_NAME,
|
538 |
-
cache_dir=cache_dir,
|
539 |
-
force_download=force_download,
|
540 |
-
resume_download=resume_download,
|
541 |
-
proxies=proxies,
|
542 |
-
local_files_only=local_files_only,
|
543 |
-
use_auth_token=use_auth_token,
|
544 |
-
revision=revision,
|
545 |
-
subfolder=subfolder,
|
546 |
-
user_agent=user_agent,
|
547 |
-
commit_hash=commit_hash,
|
548 |
-
)
|
549 |
-
model = cls.from_config(config, **unused_kwargs)
|
550 |
-
|
551 |
-
# Convert the weights
|
552 |
-
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
553 |
-
|
554 |
-
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
555 |
-
else:
|
556 |
-
if use_safetensors:
|
557 |
-
try:
|
558 |
-
model_file = _get_model_file(
|
559 |
-
pretrained_model_name_or_path,
|
560 |
-
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
561 |
-
cache_dir=cache_dir,
|
562 |
-
force_download=force_download,
|
563 |
-
resume_download=resume_download,
|
564 |
-
proxies=proxies,
|
565 |
-
local_files_only=local_files_only,
|
566 |
-
use_auth_token=use_auth_token,
|
567 |
-
revision=revision,
|
568 |
-
subfolder=subfolder,
|
569 |
-
user_agent=user_agent,
|
570 |
-
commit_hash=commit_hash,
|
571 |
-
)
|
572 |
-
except IOError as e:
|
573 |
-
if not allow_pickle:
|
574 |
-
raise e
|
575 |
-
pass
|
576 |
-
if model_file is None:
|
577 |
-
model_file = _get_model_file(
|
578 |
-
pretrained_model_name_or_path,
|
579 |
-
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
580 |
-
cache_dir=cache_dir,
|
581 |
-
force_download=force_download,
|
582 |
-
resume_download=resume_download,
|
583 |
-
proxies=proxies,
|
584 |
-
local_files_only=local_files_only,
|
585 |
-
use_auth_token=use_auth_token,
|
586 |
-
revision=revision,
|
587 |
-
subfolder=subfolder,
|
588 |
-
user_agent=user_agent,
|
589 |
-
commit_hash=commit_hash,
|
590 |
-
)
|
591 |
-
|
592 |
-
if low_cpu_mem_usage:
|
593 |
-
# Instantiate model with empty weights
|
594 |
-
with accelerate.init_empty_weights():
|
595 |
-
model = cls.from_config(config, **unused_kwargs)
|
596 |
-
|
597 |
-
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
598 |
-
if device_map is None:
|
599 |
-
param_device = "cpu"
|
600 |
-
state_dict = load_state_dict(model_file, variant=variant)
|
601 |
-
model._convert_deprecated_attention_blocks(state_dict)
|
602 |
-
# move the params from meta device to cpu
|
603 |
-
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
604 |
-
if len(missing_keys) > 0:
|
605 |
-
raise ValueError(
|
606 |
-
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
607 |
-
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
608 |
-
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
609 |
-
" those weights or else make sure your checkpoint file is correct."
|
610 |
-
)
|
611 |
-
|
612 |
-
empty_state_dict = model.state_dict()
|
613 |
-
for param_name, param in state_dict.items():
|
614 |
-
accepts_dtype = "dtype" in set(
|
615 |
-
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
616 |
-
)
|
617 |
-
|
618 |
-
if empty_state_dict[param_name].shape != param.shape:
|
619 |
-
raise ValueError(
|
620 |
-
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
621 |
-
)
|
622 |
-
|
623 |
-
if accepts_dtype:
|
624 |
-
set_module_tensor_to_device(
|
625 |
-
model, param_name, param_device, value=param, dtype=torch_dtype
|
626 |
-
)
|
627 |
-
else:
|
628 |
-
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
629 |
-
else: # else let accelerate handle loading and dispatching.
|
630 |
-
# Load weights and dispatch according to the device_map
|
631 |
-
# by default the device_map is None and the weights are loaded on the CPU
|
632 |
-
accelerate.load_checkpoint_and_dispatch(
|
633 |
-
model,
|
634 |
-
model_file,
|
635 |
-
device_map,
|
636 |
-
max_memory=max_memory,
|
637 |
-
offload_folder=offload_folder,
|
638 |
-
offload_state_dict=offload_state_dict,
|
639 |
-
dtype=torch_dtype,
|
640 |
-
)
|
641 |
-
|
642 |
-
loading_info = {
|
643 |
-
"missing_keys": [],
|
644 |
-
"unexpected_keys": [],
|
645 |
-
"mismatched_keys": [],
|
646 |
-
"error_msgs": [],
|
647 |
-
}
|
648 |
-
else:
|
649 |
-
model = cls.from_config(config, **unused_kwargs)
|
650 |
-
|
651 |
-
state_dict = load_state_dict(model_file, variant=variant)
|
652 |
-
model._convert_deprecated_attention_blocks(state_dict)
|
653 |
-
|
654 |
-
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
655 |
-
model,
|
656 |
-
state_dict,
|
657 |
-
model_file,
|
658 |
-
pretrained_model_name_or_path,
|
659 |
-
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
660 |
-
)
|
661 |
-
|
662 |
-
loading_info = {
|
663 |
-
"missing_keys": missing_keys,
|
664 |
-
"unexpected_keys": unexpected_keys,
|
665 |
-
"mismatched_keys": mismatched_keys,
|
666 |
-
"error_msgs": error_msgs,
|
667 |
-
}
|
668 |
-
|
669 |
-
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
670 |
-
raise ValueError(
|
671 |
-
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
672 |
-
)
|
673 |
-
elif torch_dtype is not None:
|
674 |
-
model = model.to(torch_dtype)
|
675 |
-
|
676 |
-
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
677 |
-
|
678 |
-
# Set model in evaluation mode to deactivate DropOut modules by default
|
679 |
-
model.eval()
|
680 |
-
if output_loading_info:
|
681 |
-
return model, loading_info
|
682 |
-
|
683 |
-
return model
|
684 |
-
|
685 |
-
@classmethod
|
686 |
-
def _load_pretrained_model(
|
687 |
-
cls,
|
688 |
-
model,
|
689 |
-
state_dict,
|
690 |
-
resolved_archive_file,
|
691 |
-
pretrained_model_name_or_path,
|
692 |
-
ignore_mismatched_sizes=False,
|
693 |
-
):
|
694 |
-
# Retrieve missing & unexpected_keys
|
695 |
-
model_state_dict = model.state_dict()
|
696 |
-
loaded_keys = list(state_dict.keys())
|
697 |
-
|
698 |
-
expected_keys = list(model_state_dict.keys())
|
699 |
-
|
700 |
-
original_loaded_keys = loaded_keys
|
701 |
-
|
702 |
-
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
703 |
-
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
704 |
-
|
705 |
-
# Make sure we are able to load base models as well as derived models (with heads)
|
706 |
-
model_to_load = model
|
707 |
-
|
708 |
-
def _find_mismatched_keys(
|
709 |
-
state_dict,
|
710 |
-
model_state_dict,
|
711 |
-
loaded_keys,
|
712 |
-
ignore_mismatched_sizes,
|
713 |
-
):
|
714 |
-
mismatched_keys = []
|
715 |
-
if ignore_mismatched_sizes:
|
716 |
-
for checkpoint_key in loaded_keys:
|
717 |
-
model_key = checkpoint_key
|
718 |
-
|
719 |
-
if (
|
720 |
-
model_key in model_state_dict
|
721 |
-
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
722 |
-
):
|
723 |
-
mismatched_keys.append(
|
724 |
-
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
725 |
-
)
|
726 |
-
del state_dict[checkpoint_key]
|
727 |
-
return mismatched_keys
|
728 |
-
|
729 |
-
if state_dict is not None:
|
730 |
-
# Whole checkpoint
|
731 |
-
mismatched_keys = _find_mismatched_keys(
|
732 |
-
state_dict,
|
733 |
-
model_state_dict,
|
734 |
-
original_loaded_keys,
|
735 |
-
ignore_mismatched_sizes,
|
736 |
-
)
|
737 |
-
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
738 |
-
|
739 |
-
if len(error_msgs) > 0:
|
740 |
-
error_msg = "\n\t".join(error_msgs)
|
741 |
-
if "size mismatch" in error_msg:
|
742 |
-
error_msg += (
|
743 |
-
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
744 |
-
)
|
745 |
-
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
746 |
-
|
747 |
-
if len(unexpected_keys) > 0:
|
748 |
-
logger.warning(
|
749 |
-
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
750 |
-
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
751 |
-
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
752 |
-
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
753 |
-
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
754 |
-
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
755 |
-
" identical (initializing a BertForSequenceClassification model from a"
|
756 |
-
" BertForSequenceClassification model)."
|
757 |
-
)
|
758 |
-
else:
|
759 |
-
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
760 |
-
if len(missing_keys) > 0:
|
761 |
-
logger.warning(
|
762 |
-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
763 |
-
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
764 |
-
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
765 |
-
)
|
766 |
-
elif len(mismatched_keys) == 0:
|
767 |
-
logger.info(
|
768 |
-
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
769 |
-
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
770 |
-
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
771 |
-
" without further training."
|
772 |
-
)
|
773 |
-
if len(mismatched_keys) > 0:
|
774 |
-
mismatched_warning = "\n".join(
|
775 |
-
[
|
776 |
-
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
777 |
-
for key, shape1, shape2 in mismatched_keys
|
778 |
-
]
|
779 |
-
)
|
780 |
-
logger.warning(
|
781 |
-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
782 |
-
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
783 |
-
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
784 |
-
" able to use it for predictions and inference."
|
785 |
-
)
|
786 |
-
|
787 |
-
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
788 |
-
|
789 |
-
@property
|
790 |
-
def device(self) -> device:
|
791 |
-
"""
|
792 |
-
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
793 |
-
device).
|
794 |
-
"""
|
795 |
-
return get_parameter_device(self)
|
796 |
-
|
797 |
-
@property
|
798 |
-
def dtype(self) -> torch.dtype:
|
799 |
-
"""
|
800 |
-
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
801 |
-
"""
|
802 |
-
return get_parameter_dtype(self)
|
803 |
-
|
804 |
-
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
805 |
-
"""
|
806 |
-
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
807 |
-
|
808 |
-
Args:
|
809 |
-
only_trainable (`bool`, *optional*, defaults to `False`):
|
810 |
-
Whether or not to return only the number of trainable parameters
|
811 |
-
|
812 |
-
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
813 |
-
Whether or not to return only the number of non-embeddings parameters
|
814 |
-
|
815 |
-
Returns:
|
816 |
-
`int`: The number of parameters.
|
817 |
-
"""
|
818 |
-
|
819 |
-
if exclude_embeddings:
|
820 |
-
embedding_param_names = [
|
821 |
-
f"{name}.weight"
|
822 |
-
for name, module_type in self.named_modules()
|
823 |
-
if isinstance(module_type, torch.nn.Embedding)
|
824 |
-
]
|
825 |
-
non_embedding_parameters = [
|
826 |
-
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
827 |
-
]
|
828 |
-
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
829 |
-
else:
|
830 |
-
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
831 |
-
|
832 |
-
def _convert_deprecated_attention_blocks(self, state_dict):
|
833 |
-
deprecated_attention_block_paths = []
|
834 |
-
|
835 |
-
def recursive_find_attn_block(name, module):
|
836 |
-
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
837 |
-
deprecated_attention_block_paths.append(name)
|
838 |
-
|
839 |
-
for sub_name, sub_module in module.named_children():
|
840 |
-
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
841 |
-
recursive_find_attn_block(sub_name, sub_module)
|
842 |
-
|
843 |
-
recursive_find_attn_block("", self)
|
844 |
-
|
845 |
-
# NOTE: we have to check if the deprecated parameters are in the state dict
|
846 |
-
# because it is possible we are loading from a state dict that was already
|
847 |
-
# converted
|
848 |
-
|
849 |
-
for path in deprecated_attention_block_paths:
|
850 |
-
# group_norm path stays the same
|
851 |
-
|
852 |
-
# query -> to_q
|
853 |
-
if f"{path}.query.weight" in state_dict:
|
854 |
-
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
855 |
-
if f"{path}.query.bias" in state_dict:
|
856 |
-
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
857 |
-
|
858 |
-
# key -> to_k
|
859 |
-
if f"{path}.key.weight" in state_dict:
|
860 |
-
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
861 |
-
if f"{path}.key.bias" in state_dict:
|
862 |
-
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
863 |
-
|
864 |
-
# value -> to_v
|
865 |
-
if f"{path}.value.weight" in state_dict:
|
866 |
-
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
867 |
-
if f"{path}.value.bias" in state_dict:
|
868 |
-
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
869 |
-
|
870 |
-
# proj_attn -> to_out.0
|
871 |
-
if f"{path}.proj_attn.weight" in state_dict:
|
872 |
-
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
873 |
-
if f"{path}.proj_attn.bias" in state_dict:
|
874 |
-
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/pipelines.py
CHANGED
@@ -1,12 +1,85 @@
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
|
|
3 |
import utils
|
4 |
-
from utils import schedule
|
5 |
from PIL import Image
|
6 |
import gc
|
7 |
import numpy as np
|
8 |
from .attention import GatedSelfAttentionDense
|
9 |
from .models import process_input_embeddings, torch_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
@torch.no_grad()
|
12 |
def encode(model_dict, image, generator):
|
@@ -53,6 +126,126 @@ def decode(vae, latents):
|
|
53 |
|
54 |
return images
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
@torch.no_grad()
|
57 |
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
|
58 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
@@ -132,9 +325,13 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
132 |
frozen_steps=20, frozen_mask=None,
|
133 |
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
|
134 |
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
|
|
|
135 |
return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
|
136 |
"""
|
137 |
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
|
|
|
|
|
|
|
138 |
"""
|
139 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
140 |
|
@@ -161,6 +358,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
161 |
if fast_after_steps is not None:
|
162 |
scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
|
163 |
|
|
|
|
|
|
|
164 |
if frozen_mask is not None:
|
165 |
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
166 |
|
@@ -171,6 +371,23 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
171 |
|
172 |
boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
if return_saved_cross_attn:
|
175 |
saved_attns = []
|
176 |
|
@@ -196,6 +413,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
196 |
if index == num_grounding_steps:
|
197 |
gligen_enable_fuser(unet, False)
|
198 |
|
|
|
|
|
|
|
199 |
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
200 |
latent_model_input = torch.cat([latents] * 2)
|
201 |
|
@@ -215,7 +435,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
215 |
# perform guidance
|
216 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
217 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
218 |
-
|
219 |
if dynamic_num_inference_steps:
|
220 |
schedule.dynamically_adjust_inference_steps(scheduler, index, t)
|
221 |
|
@@ -225,12 +445,17 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
225 |
if frozen_mask is not None and index < frozen_steps:
|
226 |
latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
|
227 |
|
|
|
228 |
if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
|
229 |
if offload_latents_to_cpu:
|
230 |
latents_all.append(latents.cpu())
|
231 |
else:
|
232 |
latents_all.append(latents)
|
233 |
|
|
|
|
|
|
|
|
|
234 |
# Turn off fuser for typical SD
|
235 |
gligen_enable_fuser(unet, False)
|
236 |
images = decode(vae, latents)
|
@@ -247,3 +472,128 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
247 |
|
248 |
return tuple(ret)
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
+
from utils import guidance, schedule, boxdiff
|
4 |
import utils
|
|
|
5 |
from PIL import Image
|
6 |
import gc
|
7 |
import numpy as np
|
8 |
from .attention import GatedSelfAttentionDense
|
9 |
from .models import process_input_embeddings, torch_device
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
# All keys: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
|
13 |
+
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
|
14 |
+
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
|
15 |
+
|
16 |
+
def latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, loss_scale = 30, loss_threshold = 0.2, max_iter = 5, max_index_step = 10, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, clear_cache=False, **kwargs):
|
17 |
+
|
18 |
+
iteration = 0
|
19 |
+
|
20 |
+
if index < max_index_step:
|
21 |
+
if isinstance(max_iter, list):
|
22 |
+
if len(max_iter) > index:
|
23 |
+
max_iter = max_iter[index]
|
24 |
+
else:
|
25 |
+
max_iter = max_iter[-1]
|
26 |
+
|
27 |
+
if verbose:
|
28 |
+
print(f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}")
|
29 |
+
|
30 |
+
while (loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step):
|
31 |
+
saved_attn = {}
|
32 |
+
full_cross_attention_kwargs = {
|
33 |
+
'save_attn_to_dict': saved_attn,
|
34 |
+
'save_keys': guidance_attn_keys,
|
35 |
+
}
|
36 |
+
|
37 |
+
if cross_attention_kwargs is not None:
|
38 |
+
full_cross_attention_kwargs.update(cross_attention_kwargs)
|
39 |
+
|
40 |
+
latents.requires_grad_(True)
|
41 |
+
latent_model_input = latents
|
42 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
43 |
+
|
44 |
+
unet(latent_model_input, t, encoder_hidden_states=cond_embeddings, return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)
|
45 |
+
|
46 |
+
# TODO: could return the attention maps for the required blocks only and not necessarily the final output
|
47 |
+
# update latents with guidance
|
48 |
+
loss = guidance.compute_ca_lossv3(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * loss_scale
|
49 |
+
|
50 |
+
if torch.isnan(loss):
|
51 |
+
print("**Loss is NaN**")
|
52 |
+
|
53 |
+
del full_cross_attention_kwargs, saved_attn
|
54 |
+
# call gc.collect() here may release some memory
|
55 |
+
|
56 |
+
grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
|
57 |
+
|
58 |
+
latents.requires_grad_(False)
|
59 |
+
|
60 |
+
if hasattr(scheduler, 'sigmas'):
|
61 |
+
latents = latents - grad_cond * scheduler.sigmas[index] ** 2
|
62 |
+
elif hasattr(scheduler, 'alphas_cumprod'):
|
63 |
+
warnings.warn("Using guidance scaled with alphas_cumprod")
|
64 |
+
# Scaling with classifier guidance
|
65 |
+
alpha_prod_t = scheduler.alphas_cumprod[t]
|
66 |
+
# Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
|
67 |
+
# DDIM: https://arxiv.org/pdf/2010.02502.pdf
|
68 |
+
scale = (1 - alpha_prod_t) ** (0.5)
|
69 |
+
latents = latents - scale * grad_cond
|
70 |
+
else:
|
71 |
+
# NOTE: no scaling is performed
|
72 |
+
warnings.warn("No scaling in guidance is performed")
|
73 |
+
latents = latents - grad_cond
|
74 |
+
iteration += 1
|
75 |
+
|
76 |
+
if clear_cache:
|
77 |
+
utils.free_memory()
|
78 |
+
|
79 |
+
if verbose:
|
80 |
+
print(f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}")
|
81 |
+
|
82 |
+
return latents, loss
|
83 |
|
84 |
@torch.no_grad()
|
85 |
def encode(model_dict, image, generator):
|
|
|
126 |
|
127 |
return images
|
128 |
|
129 |
+
def generate_semantic_guidance(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, object_positions, guidance_scale = 7.5, semantic_guidance_kwargs=None,
|
130 |
+
return_cross_attn=False, return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None, offload_guidance_cross_attn_to_cpu=False,
|
131 |
+
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True, return_box_vis=False, show_progress=True, save_all_latents=False,
|
132 |
+
dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2, use_boxdiff=False):
|
133 |
+
"""
|
134 |
+
object_positions: object indices in text tokens
|
135 |
+
return_cross_attn: should be deprecated. Use `return_saved_cross_attn` and the new format.
|
136 |
+
"""
|
137 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
138 |
+
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
|
139 |
+
|
140 |
+
# Just in case that we have in-place ops
|
141 |
+
latents = latents.clone()
|
142 |
+
|
143 |
+
if save_all_latents:
|
144 |
+
# offload to cpu to save space
|
145 |
+
if offload_latents_to_cpu:
|
146 |
+
latents_all = [latents.cpu()]
|
147 |
+
else:
|
148 |
+
latents_all = [latents]
|
149 |
+
|
150 |
+
scheduler.set_timesteps(num_inference_steps)
|
151 |
+
if fast_after_steps is not None:
|
152 |
+
scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
|
153 |
+
|
154 |
+
if dynamic_num_inference_steps:
|
155 |
+
original_num_inference_steps = scheduler.num_inference_steps
|
156 |
+
|
157 |
+
cross_attention_probs_down = []
|
158 |
+
cross_attention_probs_mid = []
|
159 |
+
cross_attention_probs_up = []
|
160 |
+
|
161 |
+
loss = torch.tensor(10000.)
|
162 |
+
|
163 |
+
# TODO: we can also save necessary tokens only to save memory.
|
164 |
+
# offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
|
165 |
+
guidance_cross_attention_kwargs = {
|
166 |
+
'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
|
167 |
+
'enable_flash_attn': False
|
168 |
+
}
|
169 |
+
|
170 |
+
if return_saved_cross_attn:
|
171 |
+
saved_attns = []
|
172 |
+
|
173 |
+
main_cross_attention_kwargs = {
|
174 |
+
'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
|
175 |
+
'return_cond_ca_only': return_cond_ca_only,
|
176 |
+
'return_token_ca_only': return_token_ca_only,
|
177 |
+
'save_keys': saved_cross_attn_keys,
|
178 |
+
}
|
179 |
+
|
180 |
+
# Repeating keys leads to different weights for each key.
|
181 |
+
# assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}"
|
182 |
+
|
183 |
+
for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)):
|
184 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
185 |
+
|
186 |
+
if bboxes:
|
187 |
+
if use_boxdiff:
|
188 |
+
latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
189 |
+
else:
|
190 |
+
# If encountered None in `guidance_attn_keys`, please be sure to check whether `guidance_attn_keys` is added in `semantic_guidance_kwargs`. Default value has been removed.
|
191 |
+
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
192 |
+
|
193 |
+
# predict the noise residual
|
194 |
+
with torch.no_grad():
|
195 |
+
latent_model_input = torch.cat([latents] * 2)
|
196 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
|
197 |
+
|
198 |
+
main_cross_attention_kwargs['save_attn_to_dict'] = {}
|
199 |
+
|
200 |
+
unet_output = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_cross_attention_probs=return_cross_attn, cross_attention_kwargs=main_cross_attention_kwargs)
|
201 |
+
noise_pred = unet_output.sample
|
202 |
+
|
203 |
+
if return_cross_attn:
|
204 |
+
cross_attention_probs_down.append(unet_output.cross_attention_probs_down)
|
205 |
+
cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid)
|
206 |
+
cross_attention_probs_up.append(unet_output.cross_attention_probs_up)
|
207 |
+
|
208 |
+
if return_saved_cross_attn:
|
209 |
+
saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
|
210 |
+
|
211 |
+
del main_cross_attention_kwargs['save_attn_to_dict']
|
212 |
+
|
213 |
+
# perform guidance
|
214 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
215 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
216 |
+
|
217 |
+
if dynamic_num_inference_steps:
|
218 |
+
schedule.dynamically_adjust_inference_steps(scheduler, index, t)
|
219 |
+
|
220 |
+
# compute the previous noisy sample x_t -> x_t-1
|
221 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
222 |
+
|
223 |
+
if save_all_latents:
|
224 |
+
if offload_latents_to_cpu:
|
225 |
+
latents_all.append(latents.cpu())
|
226 |
+
else:
|
227 |
+
latents_all.append(latents)
|
228 |
+
|
229 |
+
if dynamic_num_inference_steps:
|
230 |
+
# Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
|
231 |
+
scheduler.num_inference_steps = original_num_inference_steps
|
232 |
+
|
233 |
+
images = decode(vae, latents)
|
234 |
+
|
235 |
+
ret = [latents, images]
|
236 |
+
|
237 |
+
if return_cross_attn:
|
238 |
+
ret.append((cross_attention_probs_down, cross_attention_probs_mid, cross_attention_probs_up))
|
239 |
+
if return_saved_cross_attn:
|
240 |
+
ret.append(saved_attns)
|
241 |
+
if return_box_vis:
|
242 |
+
pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
|
243 |
+
ret.append(pil_images)
|
244 |
+
if save_all_latents:
|
245 |
+
latents_all = torch.stack(latents_all, dim=0)
|
246 |
+
ret.append(latents_all)
|
247 |
+
return tuple(ret)
|
248 |
+
|
249 |
@torch.no_grad()
|
250 |
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
|
251 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
|
|
325 |
frozen_steps=20, frozen_mask=None,
|
326 |
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
|
327 |
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
|
328 |
+
semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
|
329 |
return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
|
330 |
"""
|
331 |
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
|
332 |
+
batched:
|
333 |
+
Enabled: bboxes and phrases should be a list (batch dimension) of items (specify the bboxes/phrases of each image in the batch).
|
334 |
+
Disabled: bboxes and phrases should be a list of bboxes and phrases specifying the bboxes/phrases of one image (no batch dimension).
|
335 |
"""
|
336 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
337 |
|
|
|
358 |
if fast_after_steps is not None:
|
359 |
scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
|
360 |
|
361 |
+
if dynamic_num_inference_steps:
|
362 |
+
original_num_inference_steps = scheduler.num_inference_steps
|
363 |
+
|
364 |
if frozen_mask is not None:
|
365 |
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
366 |
|
|
|
371 |
|
372 |
boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
|
373 |
|
374 |
+
if semantic_guidance_bboxes and semantic_guidance:
|
375 |
+
loss = torch.tensor(10000.)
|
376 |
+
# TODO: we can also save necessary tokens only to save memory.
|
377 |
+
# offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
|
378 |
+
guidance_cross_attention_kwargs = {
|
379 |
+
'offload_cross_attn_to_cpu': False,
|
380 |
+
'enable_flash_attn': False,
|
381 |
+
'gligen': {
|
382 |
+
'boxes': boxes[:condition_len // 2],
|
383 |
+
'positive_embeddings': phrase_embeddings[:condition_len // 2],
|
384 |
+
'masks': masks[:condition_len // 2],
|
385 |
+
'fuser_attn_kwargs': {
|
386 |
+
'enable_flash_attn': False,
|
387 |
+
}
|
388 |
+
}
|
389 |
+
}
|
390 |
+
|
391 |
if return_saved_cross_attn:
|
392 |
saved_attns = []
|
393 |
|
|
|
413 |
if index == num_grounding_steps:
|
414 |
gligen_enable_fuser(unet, False)
|
415 |
|
416 |
+
if semantic_guidance_bboxes and semantic_guidance:
|
417 |
+
with torch.enable_grad():
|
418 |
+
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
419 |
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
420 |
latent_model_input = torch.cat([latents] * 2)
|
421 |
|
|
|
435 |
# perform guidance
|
436 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
437 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
438 |
+
|
439 |
if dynamic_num_inference_steps:
|
440 |
schedule.dynamically_adjust_inference_steps(scheduler, index, t)
|
441 |
|
|
|
445 |
if frozen_mask is not None and index < frozen_steps:
|
446 |
latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
|
447 |
|
448 |
+
# Do not save the latents in the fast steps
|
449 |
if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
|
450 |
if offload_latents_to_cpu:
|
451 |
latents_all.append(latents.cpu())
|
452 |
else:
|
453 |
latents_all.append(latents)
|
454 |
|
455 |
+
if dynamic_num_inference_steps:
|
456 |
+
# Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
|
457 |
+
scheduler.num_inference_steps = original_num_inference_steps
|
458 |
+
|
459 |
# Turn off fuser for typical SD
|
460 |
gligen_enable_fuser(unet, False)
|
461 |
images = decode(vae, latents)
|
|
|
472 |
|
473 |
return tuple(ret)
|
474 |
|
475 |
+
|
476 |
+
def get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength):
|
477 |
+
# get the original timestep using init_timestep
|
478 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
479 |
+
|
480 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
481 |
+
|
482 |
+
# safety for t_start overflow to prevent empty timsteps slice
|
483 |
+
if t_start == 0:
|
484 |
+
return inverse_scheduler.timesteps, num_inference_steps
|
485 |
+
timesteps = inverse_scheduler.timesteps[:-t_start]
|
486 |
+
|
487 |
+
return timesteps, num_inference_steps - t_start
|
488 |
+
|
489 |
+
@torch.no_grad()
|
490 |
+
def invert(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5):
|
491 |
+
"""
|
492 |
+
latents: encoded from the image, should not have noise (t = 0)
|
493 |
+
|
494 |
+
returns inverted_latents for all time steps
|
495 |
+
"""
|
496 |
+
vae, tokenizer, text_encoder, unet, scheduler, inverse_scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.inverse_scheduler, model_dict.dtype
|
497 |
+
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
|
498 |
+
|
499 |
+
inverse_scheduler.set_timesteps(num_inference_steps, device=latents.device)
|
500 |
+
# We need to invert all steps because we need them to generate the background.
|
501 |
+
timesteps, num_inference_steps = get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength=1.0)
|
502 |
+
|
503 |
+
inverted_latents = [latents.cpu()]
|
504 |
+
for t in tqdm(timesteps[:-1]):
|
505 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
506 |
+
if guidance_scale > 0.:
|
507 |
+
latent_model_input = torch.cat([latents] * 2)
|
508 |
+
|
509 |
+
latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
|
510 |
+
|
511 |
+
# predict the noise residual
|
512 |
+
with torch.no_grad():
|
513 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
514 |
+
|
515 |
+
# perform guidance
|
516 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
517 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
518 |
+
else:
|
519 |
+
latent_model_input = latents
|
520 |
+
|
521 |
+
latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
|
522 |
+
|
523 |
+
# predict the noise residual
|
524 |
+
with torch.no_grad():
|
525 |
+
noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample
|
526 |
+
|
527 |
+
# perform guidance
|
528 |
+
noise_pred = noise_pred_uncond
|
529 |
+
|
530 |
+
# compute the previous noisy sample x_t -> x_t-1
|
531 |
+
latents = inverse_scheduler.step(noise_pred, t, latents).prev_sample
|
532 |
+
|
533 |
+
inverted_latents.append(latents.cpu())
|
534 |
+
|
535 |
+
assert len(inverted_latents) == len(timesteps)
|
536 |
+
# timestep is the first dimension
|
537 |
+
inverted_latents = torch.stack(list(reversed(inverted_latents)), dim=0)
|
538 |
+
|
539 |
+
return inverted_latents
|
540 |
+
|
541 |
+
def generate_partial_frozen(model_dict, latents_all, frozen_mask, input_embeddings, num_inference_steps, frozen_steps, guidance_scale = 7.5, bboxes=None, phrases=None, object_positions=None, semantic_guidance_kwargs=None, offload_guidance_cross_attn_to_cpu=False, use_boxdiff=False):
|
542 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
543 |
+
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
|
544 |
+
|
545 |
+
scheduler.set_timesteps(num_inference_steps)
|
546 |
+
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
547 |
+
|
548 |
+
latents = latents_all[0]
|
549 |
+
|
550 |
+
if bboxes:
|
551 |
+
# With semantic guidance
|
552 |
+
loss = torch.tensor(10000.)
|
553 |
+
|
554 |
+
# offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
|
555 |
+
guidance_cross_attention_kwargs = {
|
556 |
+
'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
|
557 |
+
# Getting invalid argument on backward, probably due to insufficient shared memory
|
558 |
+
'enable_flash_attn': False
|
559 |
+
}
|
560 |
+
|
561 |
+
for index, t in enumerate(tqdm(scheduler.timesteps)):
|
562 |
+
if bboxes:
|
563 |
+
# With semantic guidance, `guidance_attn_keys` should be in `semantic_guidance_kwargs`
|
564 |
+
if use_boxdiff:
|
565 |
+
latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
566 |
+
else:
|
567 |
+
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
568 |
+
|
569 |
+
with torch.no_grad():
|
570 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
571 |
+
latent_model_input = torch.cat([latents] * 2)
|
572 |
+
|
573 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
|
574 |
+
|
575 |
+
# predict the noise residual
|
576 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
577 |
+
|
578 |
+
# perform guidance
|
579 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
580 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
581 |
+
|
582 |
+
# compute the previous noisy sample x_t -> x_t-1
|
583 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
584 |
+
|
585 |
+
if index < frozen_steps:
|
586 |
+
latents = latents_all[index+1] * frozen_mask + latents * (1. - frozen_mask)
|
587 |
+
|
588 |
+
# scale and decode the image latents with vae
|
589 |
+
scaled_latents = 1 / 0.18215 * latents
|
590 |
+
with torch.no_grad():
|
591 |
+
image = vae.decode(scaled_latents).sample
|
592 |
+
|
593 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
594 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
595 |
+
images = (image * 255).round().astype("uint8")
|
596 |
+
|
597 |
+
ret = [latents, images]
|
598 |
+
|
599 |
+
return tuple(ret)
|
models/sam.py
CHANGED
@@ -164,8 +164,10 @@ def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H
|
|
164 |
return mask_selected, conf_score_selected
|
165 |
|
166 |
def sam_refine_box(sam_input_image, box, *args, **kwargs):
|
167 |
-
|
168 |
-
|
|
|
|
|
169 |
|
170 |
def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
|
171 |
# (w, h)
|
|
|
164 |
return mask_selected, conf_score_selected
|
165 |
|
166 |
def sam_refine_box(sam_input_image, box, *args, **kwargs):
|
167 |
+
# One image with one box
|
168 |
+
sam_input_images, boxes = [sam_input_image], [[box]]
|
169 |
+
mask_selected_batched_list, conf_score_selected_batched_list = sam_refine_boxes(sam_input_images, boxes, *args, **kwargs)
|
170 |
+
return mask_selected_batched_list[0][0], conf_score_selected_batched_list[0][0]
|
171 |
|
172 |
def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
|
173 |
# (w, h)
|
utils/attn.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# visualization-related functions are in vis
|
2 |
+
import numbers
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
import utils
|
8 |
+
|
9 |
+
def get_token_attnv2(token_id, saved_attns, attn_key, visualize_step_start=10, input_ca_has_condition_only=False, return_np=False):
|
10 |
+
"""
|
11 |
+
saved_attns: a list of saved_attn (list is across timesteps)
|
12 |
+
|
13 |
+
moves to cpu by default
|
14 |
+
"""
|
15 |
+
saved_attns = saved_attns[visualize_step_start:]
|
16 |
+
|
17 |
+
saved_attns = [saved_attn[attn_key].cpu() for saved_attn in saved_attns]
|
18 |
+
|
19 |
+
attn = torch.stack(saved_attns, dim=0).mean(dim=0)
|
20 |
+
|
21 |
+
# print("attn shape", attn.shape)
|
22 |
+
|
23 |
+
# attn: (batch, head, spatial, text)
|
24 |
+
|
25 |
+
if not input_ca_has_condition_only:
|
26 |
+
assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
|
27 |
+
attn = attn[1]
|
28 |
+
else:
|
29 |
+
assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
|
30 |
+
attn = attn[0]
|
31 |
+
attn = attn.mean(dim=0)[:, token_id]
|
32 |
+
H = W = int(math.sqrt(attn.shape[0]))
|
33 |
+
attn = attn.reshape((H, W))
|
34 |
+
|
35 |
+
if return_np:
|
36 |
+
return attn.numpy()
|
37 |
+
|
38 |
+
return attn
|
39 |
+
|
40 |
+
def shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, horizontal_shift_only=False):
|
41 |
+
"""
|
42 |
+
`horizontal_shift_only`: only shift horizontally. If you use `offset` from `compose_latents_with_alignment` with `horizontal_shift_only=True`, the `offset` already has y_offset = 0 and this option is not needed.
|
43 |
+
"""
|
44 |
+
x_offset, y_offset = offset
|
45 |
+
if horizontal_shift_only:
|
46 |
+
y_offset = 0.
|
47 |
+
|
48 |
+
new_saved_attns_item = {}
|
49 |
+
for k in guidance_attn_keys:
|
50 |
+
attn_map = saved_attns_item[k]
|
51 |
+
|
52 |
+
attn_size = attn_map.shape[-2]
|
53 |
+
attn_h = attn_w = int(math.sqrt(attn_size))
|
54 |
+
# Example dimensions: [batch_size, num_heads, 8, 8, num_tokens]
|
55 |
+
attn_map = attn_map.unflatten(2, (attn_h, attn_w))
|
56 |
+
attn_map = utils.shift_tensor(
|
57 |
+
attn_map, x_offset, y_offset,
|
58 |
+
offset_normalized=True, ignore_last_dim=True
|
59 |
+
)
|
60 |
+
attn_map = attn_map.flatten(2, 3)
|
61 |
+
|
62 |
+
new_saved_attns_item[k] = attn_map
|
63 |
+
|
64 |
+
return new_saved_attns_item
|
65 |
+
|
66 |
+
def shift_saved_attns(saved_attns, offset, guidance_attn_keys, **kwargs):
|
67 |
+
# Iterate over timesteps
|
68 |
+
shifted_saved_attns = [shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, **kwargs) for saved_attns_item in saved_attns]
|
69 |
+
|
70 |
+
return shifted_saved_attns
|
71 |
+
|
72 |
+
|
73 |
+
class GaussianSmoothing(nn.Module):
|
74 |
+
"""
|
75 |
+
Apply gaussian smoothing on a
|
76 |
+
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
77 |
+
in the input using a depthwise convolution.
|
78 |
+
Arguments:
|
79 |
+
channels (int, sequence): Number of channels of the input tensors. Output will
|
80 |
+
have this number of channels as well.
|
81 |
+
kernel_size (int, sequence): Size of the gaussian kernel.
|
82 |
+
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
83 |
+
dim (int, optional): The number of dimensions of the data.
|
84 |
+
Default value is 2 (spatial).
|
85 |
+
|
86 |
+
Credit: https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, channels, kernel_size, sigma, dim=2):
|
90 |
+
super(GaussianSmoothing, self).__init__()
|
91 |
+
if isinstance(kernel_size, numbers.Number):
|
92 |
+
kernel_size = [kernel_size] * dim
|
93 |
+
if isinstance(sigma, numbers.Number):
|
94 |
+
sigma = [sigma] * dim
|
95 |
+
|
96 |
+
# The gaussian kernel is the product of the
|
97 |
+
# gaussian function of each dimension.
|
98 |
+
kernel = 1
|
99 |
+
meshgrids = torch.meshgrid(
|
100 |
+
[
|
101 |
+
torch.arange(size, dtype=torch.float32)
|
102 |
+
for size in kernel_size
|
103 |
+
]
|
104 |
+
)
|
105 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
106 |
+
mean = (size - 1) / 2
|
107 |
+
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
108 |
+
torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
|
109 |
+
|
110 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
111 |
+
kernel = kernel / torch.sum(kernel)
|
112 |
+
|
113 |
+
# Reshape to depthwise convolutional weight
|
114 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
115 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
116 |
+
|
117 |
+
self.register_buffer('weight', kernel)
|
118 |
+
self.groups = channels
|
119 |
+
|
120 |
+
if dim == 1:
|
121 |
+
self.conv = F.conv1d
|
122 |
+
elif dim == 2:
|
123 |
+
self.conv = F.conv2d
|
124 |
+
elif dim == 3:
|
125 |
+
self.conv = F.conv3d
|
126 |
+
else:
|
127 |
+
raise RuntimeError(
|
128 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(
|
129 |
+
dim)
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(self, input):
|
133 |
+
"""
|
134 |
+
Apply gaussian filter to input.
|
135 |
+
Arguments:
|
136 |
+
input (torch.Tensor): Input to apply gaussian filter on.
|
137 |
+
Returns:
|
138 |
+
filtered (torch.Tensor): Filtered output.
|
139 |
+
"""
|
140 |
+
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
|
utils/boxdiff.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an reimplementation boxdiff baseline for reference and comparison. It is not used in the Web UI and not enabled by default since the current attention guidance implementation (in `guidance`), which uses attention maps from multiple levels and attention transfer, seems to be more robust and coherent.
|
3 |
+
|
4 |
+
Credit: https://github.com/showlab/BoxDiff/blob/master/pipeline/sd_pipeline_boxdiff.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
import gc
|
12 |
+
from collections.abc import Iterable
|
13 |
+
import utils
|
14 |
+
from . import guidance
|
15 |
+
from .attn import GaussianSmoothing
|
16 |
+
|
17 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
18 |
+
|
19 |
+
|
20 |
+
def _compute_max_attention_per_index(attention_maps: torch.Tensor,
|
21 |
+
object_positions: List[List[int]],
|
22 |
+
smooth_attentions: bool = False,
|
23 |
+
sigma: float = 0.5,
|
24 |
+
kernel_size: int = 3,
|
25 |
+
normalize_eot: bool = False,
|
26 |
+
bboxes: List[List[int]] = None,
|
27 |
+
P: float = 0.2,
|
28 |
+
L: int = 1,
|
29 |
+
) -> List[torch.Tensor]:
|
30 |
+
""" Computes the maximum attention value for each of the tokens we wish to alter. """
|
31 |
+
last_idx = -1
|
32 |
+
assert not normalize_eot, "normalize_eot is unimplemented"
|
33 |
+
|
34 |
+
attention_for_text = attention_maps[:, :, 1:last_idx]
|
35 |
+
attention_for_text *= 100
|
36 |
+
attention_for_text = F.softmax(attention_for_text, dim=-1)
|
37 |
+
|
38 |
+
# Extract the maximum values
|
39 |
+
max_indices_list_fg = []
|
40 |
+
max_indices_list_bg = []
|
41 |
+
dist_x = []
|
42 |
+
dist_y = []
|
43 |
+
|
44 |
+
for obj_idx, text_positions_per_obj in enumerate(object_positions):
|
45 |
+
for text_position_per_obj in text_positions_per_obj:
|
46 |
+
# Shift indices since we removed the first token
|
47 |
+
image = attention_for_text[:, :, text_position_per_obj - 1]
|
48 |
+
H, W = image.shape
|
49 |
+
|
50 |
+
obj_mask = torch.zeros_like(image)
|
51 |
+
corner_mask_x = torch.zeros(
|
52 |
+
(W,), device=obj_mask.device, dtype=obj_mask.dtype)
|
53 |
+
corner_mask_y = torch.zeros(
|
54 |
+
(H,), device=obj_mask.device, dtype=obj_mask.dtype)
|
55 |
+
|
56 |
+
obj_boxes = bboxes[obj_idx]
|
57 |
+
|
58 |
+
# We support two level (one box per phrase) and three level (multiple boxes per phrase)
|
59 |
+
if not isinstance(obj_boxes[0], Iterable):
|
60 |
+
obj_boxes = [obj_boxes]
|
61 |
+
|
62 |
+
for obj_box in obj_boxes:
|
63 |
+
x_min, y_min, x_max, y_max = utils.scale_proportion(
|
64 |
+
obj_box, H=H, W=W)
|
65 |
+
obj_mask[y_min: y_max, x_min: x_max] = 1
|
66 |
+
|
67 |
+
corner_mask_x[max(x_min - L, 0): min(x_min + L + 1, W)] = 1.
|
68 |
+
corner_mask_x[max(x_max - L, 0): min(x_max + L + 1, W)] = 1.
|
69 |
+
corner_mask_y[max(y_min - L, 0): min(y_min + L + 1, H)] = 1.
|
70 |
+
corner_mask_y[max(y_max - L, 0): min(y_max + L + 1, H)] = 1.
|
71 |
+
|
72 |
+
bg_mask = 1 - obj_mask
|
73 |
+
|
74 |
+
if smooth_attentions:
|
75 |
+
smoothing = GaussianSmoothing(
|
76 |
+
channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
|
77 |
+
input = F.pad(image.unsqueeze(0).unsqueeze(0),
|
78 |
+
(1, 1, 1, 1), mode='reflect')
|
79 |
+
image = smoothing(input).squeeze(0).squeeze(0)
|
80 |
+
|
81 |
+
# Inner-Box constraint
|
82 |
+
k = (obj_mask.sum() * P).long()
|
83 |
+
max_indices_list_fg.append(
|
84 |
+
(image * obj_mask).reshape(-1).topk(k)[0].mean())
|
85 |
+
|
86 |
+
# Outer-Box constraint
|
87 |
+
k = (bg_mask.sum() * P).long()
|
88 |
+
max_indices_list_bg.append(
|
89 |
+
(image * bg_mask).reshape(-1).topk(k)[0].mean())
|
90 |
+
|
91 |
+
# Corner Constraint
|
92 |
+
gt_proj_x = torch.max(obj_mask, dim=0).values
|
93 |
+
gt_proj_y = torch.max(obj_mask, dim=1).values
|
94 |
+
|
95 |
+
# create gt according to the number L
|
96 |
+
dist_x.append((F.l1_loss(image.max(dim=0)[
|
97 |
+
0], gt_proj_x, reduction='none') * corner_mask_x).mean())
|
98 |
+
dist_y.append((F.l1_loss(image.max(dim=1)[
|
99 |
+
0], gt_proj_y, reduction='none') * corner_mask_y).mean())
|
100 |
+
|
101 |
+
return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y
|
102 |
+
|
103 |
+
|
104 |
+
def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor],
|
105 |
+
dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
|
106 |
+
""" Computes the attend-and-excite loss using the maximum attention value for each token. """
|
107 |
+
losses_fg = [max(0, 1. - curr_max)
|
108 |
+
for curr_max in max_attention_per_index_fg]
|
109 |
+
losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg]
|
110 |
+
loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)
|
111 |
+
|
112 |
+
# print(f"{losses_fg}, {losses_bg}, {dist_x}, {dist_y}, {loss}")
|
113 |
+
|
114 |
+
if return_losses:
|
115 |
+
return max(losses_fg), losses_fg
|
116 |
+
else:
|
117 |
+
return max(losses_fg), loss
|
118 |
+
|
119 |
+
|
120 |
+
def compute_ca_loss_boxdiff(saved_attn, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns=None, ref_ca_last_token_only=True, ref_ca_word_token_only=False, word_token_indices=None, index=None, ref_ca_loss_weight=1.0, verbose=False, **kwargs):
|
121 |
+
"""
|
122 |
+
v3 is equivalent to v2 but with new dictionary format for attention maps.
|
123 |
+
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
|
124 |
+
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
|
125 |
+
|
126 |
+
`index` is the timestep.
|
127 |
+
`ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
|
128 |
+
`ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
|
129 |
+
"""
|
130 |
+
loss = torch.tensor(0).float().cuda()
|
131 |
+
object_number = len(bboxes)
|
132 |
+
if object_number == 0:
|
133 |
+
return loss
|
134 |
+
|
135 |
+
attn_map_list = []
|
136 |
+
|
137 |
+
for attn_key in guidance_attn_keys:
|
138 |
+
# We only have 1 cross attention for mid.
|
139 |
+
attn_map_integrated = saved_attn[attn_key]
|
140 |
+
if not attn_map_integrated.is_cuda:
|
141 |
+
attn_map_integrated = attn_map_integrated.cuda()
|
142 |
+
# Example dimension: [20, 64, 77]
|
143 |
+
attn_map = attn_map_integrated.squeeze(dim=0)
|
144 |
+
attn_map_list.append(attn_map)
|
145 |
+
# This averages both across layers and across attention heads
|
146 |
+
attn_map = torch.cat(attn_map_list, dim=0).mean(dim=0)
|
147 |
+
loss = add_ca_loss_per_attn_map_to_loss_boxdiff(
|
148 |
+
loss, attn_map, object_number, bboxes, object_positions, verbose=verbose, **kwargs)
|
149 |
+
|
150 |
+
if ref_ca_saved_attns is not None:
|
151 |
+
warnings.warn('Attention reference loss is enabled in boxdiff mode. The original boxdiff does not have attention reference loss.')
|
152 |
+
|
153 |
+
ref_loss = torch.tensor(0).float().cuda()
|
154 |
+
ref_loss = guidance.add_ref_ca_loss_per_attn_map_to_lossv2(
|
155 |
+
ref_loss, saved_attn=saved_attn, object_number=object_number, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
|
156 |
+
ref_ca_saved_attns=ref_ca_saved_attns, ref_ca_last_token_only=ref_ca_last_token_only, ref_ca_word_token_only=ref_ca_word_token_only, word_token_indices=word_token_indices, verbose=verbose, index=index, loss_weight=ref_ca_loss_weight
|
157 |
+
)
|
158 |
+
print(f"loss {loss.item():.3f}, reference attention loss (weighted) {ref_loss.item():.3f}")
|
159 |
+
loss += ref_loss
|
160 |
+
|
161 |
+
return loss
|
162 |
+
|
163 |
+
|
164 |
+
def add_ca_loss_per_attn_map_to_loss_boxdiff(original_loss, attention_maps, object_number, bboxes, object_positions, P=0.2, L=1, smooth_attentions=True, sigma=0.5, kernel_size=3, normalize_eot=False, verbose=False):
|
165 |
+
# NOTE: normalize_eot is enabled in SD v2.1 in boxdiff
|
166 |
+
i, j = attention_maps.shape
|
167 |
+
H = W = int(math.sqrt(i))
|
168 |
+
|
169 |
+
attention_maps = attention_maps.view(H, W, j)
|
170 |
+
# attention_maps is aggregated cross attn map across layers and steps
|
171 |
+
# attention_maps shape: [H, W, 77]
|
172 |
+
max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = _compute_max_attention_per_index(
|
173 |
+
attention_maps=attention_maps,
|
174 |
+
object_positions=object_positions,
|
175 |
+
smooth_attentions=smooth_attentions,
|
176 |
+
sigma=sigma,
|
177 |
+
kernel_size=kernel_size,
|
178 |
+
normalize_eot=normalize_eot,
|
179 |
+
bboxes=bboxes,
|
180 |
+
P=P,
|
181 |
+
L=L
|
182 |
+
)
|
183 |
+
|
184 |
+
_, loss = _compute_loss(max_attention_per_index_fg,
|
185 |
+
max_attention_per_index_bg, dist_x, dist_y)
|
186 |
+
|
187 |
+
return original_loss + loss
|
188 |
+
|
189 |
+
|
190 |
+
def latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, amp_loss_scale=10, latent_scale=20, scale_range=(1., 0.5), max_index_step=25, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, **kwargs):
|
191 |
+
"""
|
192 |
+
amp_loss_scale: this scales the loss but will de-scale before applying for latents. This is to prevent overflow/underflow with amp, not to adjust the update step size.
|
193 |
+
latent_scale: this scales the step size for update (scale_factor in boxdiff).
|
194 |
+
"""
|
195 |
+
|
196 |
+
if index < max_index_step:
|
197 |
+
saved_attn = {}
|
198 |
+
full_cross_attention_kwargs = {
|
199 |
+
'save_attn_to_dict': saved_attn,
|
200 |
+
'save_keys': guidance_attn_keys,
|
201 |
+
}
|
202 |
+
|
203 |
+
if cross_attention_kwargs is not None:
|
204 |
+
full_cross_attention_kwargs.update(cross_attention_kwargs)
|
205 |
+
|
206 |
+
latents.requires_grad_(True)
|
207 |
+
latent_model_input = latents
|
208 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
209 |
+
|
210 |
+
unet(latent_model_input, t, encoder_hidden_states=cond_embeddings,
|
211 |
+
return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)
|
212 |
+
|
213 |
+
# TODO: could return the attention maps for the required blocks only and not necessarily the final output
|
214 |
+
# update latents with guidance
|
215 |
+
loss = compute_ca_loss_boxdiff(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
|
216 |
+
ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * amp_loss_scale
|
217 |
+
|
218 |
+
if torch.isnan(loss):
|
219 |
+
print("**Loss is NaN**")
|
220 |
+
|
221 |
+
del full_cross_attention_kwargs, saved_attn
|
222 |
+
# call gc.collect() here may release some memory
|
223 |
+
|
224 |
+
grad_cond = torch.autograd.grad(
|
225 |
+
loss.requires_grad_(True), [latents])[0]
|
226 |
+
|
227 |
+
latents.requires_grad_(False)
|
228 |
+
|
229 |
+
if True:
|
230 |
+
warnings.warn("Using guidance scaled with sqrt scale")
|
231 |
+
# According to boxdiff's implementation: https://github.com/Sierkinhane/BoxDiff/blob/16ffb677a9128128e04553a0200870a526731be0/pipeline/sd_pipeline_boxdiff.py#L616
|
232 |
+
scale = (scale_range[0] + (scale_range[1] - scale_range[0])
|
233 |
+
* index / (len(scheduler.timesteps) - 1)) ** (0.5)
|
234 |
+
latents = latents - latent_scale * scale / amp_loss_scale * grad_cond
|
235 |
+
elif hasattr(scheduler, 'sigmas'):
|
236 |
+
warnings.warn("Using guidance scaled with sigmas")
|
237 |
+
scale = scheduler.sigmas[index] ** 2
|
238 |
+
latents = latents - grad_cond * scale
|
239 |
+
elif hasattr(scheduler, 'alphas_cumprod'):
|
240 |
+
warnings.warn("Using guidance scaled with alphas_cumprod")
|
241 |
+
# Scaling with classifier guidance
|
242 |
+
alpha_prod_t = scheduler.alphas_cumprod[t]
|
243 |
+
# Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
|
244 |
+
# DDIM: https://arxiv.org/pdf/2010.02502.pdf
|
245 |
+
scale = (1 - alpha_prod_t) ** (0.5)
|
246 |
+
latents = latents - latent_scale * scale / amp_loss_scale * grad_cond
|
247 |
+
else:
|
248 |
+
warnings.warn("No scaling in guidance is performed")
|
249 |
+
scale = 1
|
250 |
+
latents = latents - grad_cond
|
251 |
+
|
252 |
+
gc.collect()
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
|
255 |
+
if verbose:
|
256 |
+
print(
|
257 |
+
f"time index {index}, loss: {loss.item() / amp_loss_scale:.3f} (de-scaled with scale {amp_loss_scale:.1f}), latent grad scale: {scale:.3f}")
|
258 |
+
|
259 |
+
return latents, loss
|
utils/guidance.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
from collections.abc import Iterable
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
import utils
|
8 |
+
|
9 |
+
# A list mapping: prompt index to str (prompt in a list of token str)
|
10 |
+
def get_token_map(tokenizer, prompt, verbose=False, padding="do_not_pad"):
|
11 |
+
fg_prompt_tokens = tokenizer([prompt], padding=padding, max_length=77, return_tensors="np")
|
12 |
+
input_ids = fg_prompt_tokens['input_ids'][0]
|
13 |
+
|
14 |
+
# index_to_last_with = np.max(np.where(input_ids == 593))
|
15 |
+
# index_to_last_eot = np.max(np.where(input_ids == 49407))
|
16 |
+
|
17 |
+
token_map = []
|
18 |
+
for ind, item in enumerate(input_ids.tolist()):
|
19 |
+
|
20 |
+
token = tokenizer._convert_id_to_token(item)
|
21 |
+
if verbose:
|
22 |
+
print(f"{ind}, {token} ({item})")
|
23 |
+
|
24 |
+
token_map.append(token)
|
25 |
+
|
26 |
+
# If we don't pad, we don't need to break.
|
27 |
+
# if item == tokenizer.eos_token_id:
|
28 |
+
# break
|
29 |
+
|
30 |
+
return token_map
|
31 |
+
|
32 |
+
def get_phrase_indices(tokenizer, prompt, phrases, verbose=False, words=None, include_eos=False, token_map=None, return_word_token_indices=False, add_suffix_if_not_found=False):
|
33 |
+
for obj in phrases:
|
34 |
+
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
|
35 |
+
if obj not in prompt:
|
36 |
+
prompt += "| " + obj
|
37 |
+
|
38 |
+
if token_map is None:
|
39 |
+
# We allow using a pre-computed token map.
|
40 |
+
token_map = get_token_map(tokenizer, prompt=prompt, verbose=verbose, padding="do_not_pad")
|
41 |
+
token_map_str = " ".join(token_map)
|
42 |
+
|
43 |
+
object_positions = []
|
44 |
+
word_token_indices = []
|
45 |
+
for obj_ind, obj in enumerate(phrases):
|
46 |
+
phrase_token_map = get_token_map(tokenizer, prompt=obj, verbose=verbose, padding="do_not_pad")
|
47 |
+
# Remove <bos> and <eos> in substr
|
48 |
+
phrase_token_map = phrase_token_map[1:-1]
|
49 |
+
phrase_token_map_len = len(phrase_token_map)
|
50 |
+
phrase_token_map_str = " ".join(phrase_token_map)
|
51 |
+
|
52 |
+
if verbose:
|
53 |
+
print("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
|
54 |
+
|
55 |
+
# Count the number of token before substr
|
56 |
+
# The substring comes with a trailing space that needs to be removed by minus one in the index.
|
57 |
+
obj_first_index = len(token_map_str[:token_map_str.index(phrase_token_map_str)-1].split(" "))
|
58 |
+
|
59 |
+
obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len))
|
60 |
+
if include_eos:
|
61 |
+
obj_position.append(token_map.index(tokenizer.eos_token))
|
62 |
+
object_positions.append(obj_position)
|
63 |
+
|
64 |
+
if return_word_token_indices:
|
65 |
+
# Picking the last token in the specification
|
66 |
+
if words is None:
|
67 |
+
so_token_index = object_positions[0][-1]
|
68 |
+
# Picking the noun or perform pooling on attention with the tokens may be better
|
69 |
+
print(f"Picking the last token \"{token_map[so_token_index]}\" ({so_token_index}) as attention token for extracting attention for SAM, which might not be the right one")
|
70 |
+
else:
|
71 |
+
word = words[obj_ind]
|
72 |
+
word_token_map = get_token_map(tokenizer, prompt=word, verbose=verbose, padding="do_not_pad")
|
73 |
+
# Get the index of the last token of word (the occurrence in phrase) in the prompt. Note that we skip the <eos> token through indexing with -2.
|
74 |
+
so_token_index = obj_first_index + phrase_token_map.index(word_token_map[-2])
|
75 |
+
|
76 |
+
if verbose:
|
77 |
+
print("so_token_index:", so_token_index)
|
78 |
+
|
79 |
+
word_token_indices.append(so_token_index)
|
80 |
+
|
81 |
+
if return_word_token_indices:
|
82 |
+
if add_suffix_if_not_found:
|
83 |
+
return object_positions, word_token_indices, prompt
|
84 |
+
return object_positions, word_token_indices
|
85 |
+
|
86 |
+
if add_suffix_if_not_found:
|
87 |
+
return object_positions, prompt
|
88 |
+
|
89 |
+
return object_positions
|
90 |
+
|
91 |
+
def add_ca_loss_per_attn_map_to_loss(loss, attn_map, object_number, bboxes, object_positions, use_ratio_based_loss=True, fg_top_p=0.2, bg_top_p=0.2, fg_weight=1.0, bg_weight=1.0, verbose=False):
|
92 |
+
"""
|
93 |
+
fg_top_p, bg_top_p, fg_weight, and bg_weight are only used with max-based loss
|
94 |
+
"""
|
95 |
+
|
96 |
+
# Uncomment to debug:
|
97 |
+
# print(fg_top_p, bg_top_p, fg_weight, bg_weight)
|
98 |
+
|
99 |
+
# b is the number of heads, not batch
|
100 |
+
b, i, j = attn_map.shape
|
101 |
+
H = W = int(math.sqrt(i))
|
102 |
+
for obj_idx in range(object_number):
|
103 |
+
obj_loss = 0
|
104 |
+
mask = torch.zeros(size=(H, W), device="cuda")
|
105 |
+
obj_boxes = bboxes[obj_idx]
|
106 |
+
|
107 |
+
# We support two level (one box per phrase) and three level (multiple boxes per phrase)
|
108 |
+
if not isinstance(obj_boxes[0], Iterable):
|
109 |
+
obj_boxes = [obj_boxes]
|
110 |
+
|
111 |
+
for obj_box in obj_boxes:
|
112 |
+
# x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
113 |
+
x_min, y_min, x_max, y_max = utils.scale_proportion(obj_box, H=H, W=W)
|
114 |
+
mask[y_min: y_max, x_min: x_max] = 1
|
115 |
+
|
116 |
+
for obj_position in object_positions[obj_idx]:
|
117 |
+
# Could potentially optimize to compute this for loop in batch.
|
118 |
+
# Could crop the ref cross attention before saving to save memory.
|
119 |
+
|
120 |
+
ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
|
121 |
+
|
122 |
+
if use_ratio_based_loss:
|
123 |
+
warnings.warn("Using ratio-based loss, which is deprecated. Max-based loss is recommended. The scale may be different.")
|
124 |
+
# Original loss function (ratio-based loss function)
|
125 |
+
|
126 |
+
# Enforces the attention to be within the mask only. Does not enforce within-mask distribution.
|
127 |
+
activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1)
|
128 |
+
obj_loss += torch.mean((1 - activation_value) ** 2)
|
129 |
+
# if verbose:
|
130 |
+
# print(f"enforce attn to be within the mask loss: {torch.mean((1 - activation_value) ** 2).item():.2f}")
|
131 |
+
else:
|
132 |
+
# Max-based loss function
|
133 |
+
|
134 |
+
# shape: (b, H * W)
|
135 |
+
ca_map_obj = attn_map[:, :, obj_position] # .reshape(b, H, W)
|
136 |
+
k_fg = (mask.sum() * fg_top_p).long().clamp_(min=1)
|
137 |
+
k_bg = ((1 - mask).sum() * bg_top_p).long().clamp_(min=1)
|
138 |
+
|
139 |
+
mask_1d = mask.view(1, -1)
|
140 |
+
|
141 |
+
# Take the topk over spatial dimension, and then take the sum over heads dim
|
142 |
+
# The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own.
|
143 |
+
obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight
|
144 |
+
obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight
|
145 |
+
|
146 |
+
loss += obj_loss / len(object_positions[obj_idx])
|
147 |
+
|
148 |
+
return loss
|
149 |
+
|
150 |
+
def add_ref_ca_loss_per_attn_map_to_lossv2(loss, saved_attn, object_number, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns, ref_ca_last_token_only, ref_ca_word_token_only, word_token_indices, index, loss_weight, eps=1e-5, verbose=False):
|
151 |
+
"""
|
152 |
+
This adds the ca loss with ref. Note that this should be used with ca loss without ref since it only enforces the mse of the normalized ca between ref and target.
|
153 |
+
|
154 |
+
`ref_ca_saved_attn` should have the same structure as bboxes and object_positions (until the inner content, which should be similar to saved_attn).
|
155 |
+
"""
|
156 |
+
|
157 |
+
if loss_weight == 0.:
|
158 |
+
# Skip computing the reference loss if the loss weight is 0.
|
159 |
+
return loss
|
160 |
+
|
161 |
+
for obj_idx in range(object_number):
|
162 |
+
obj_loss = 0
|
163 |
+
|
164 |
+
obj_boxes = bboxes[obj_idx]
|
165 |
+
obj_ref_ca_saved_attns = ref_ca_saved_attns[obj_idx]
|
166 |
+
|
167 |
+
# We support two level (one box per phrase) and three level (multiple boxes per phrase)
|
168 |
+
if not isinstance(obj_boxes[0], Iterable):
|
169 |
+
obj_boxes = [obj_boxes]
|
170 |
+
obj_ref_ca_saved_attns = [obj_ref_ca_saved_attns]
|
171 |
+
|
172 |
+
assert len(obj_boxes) == len(obj_ref_ca_saved_attns), f"obj_boxes: {len(obj_boxes)}, obj_ref_ca_saved_attns: {len(obj_ref_ca_saved_attns)}"
|
173 |
+
|
174 |
+
for obj_box, obj_ref_ca_saved_attn in zip(obj_boxes, obj_ref_ca_saved_attns):
|
175 |
+
# obj_ref_ca_map_items has all timesteps.
|
176 |
+
# Format: (timestep (index), attn_key, batch, heads, 2d dim, num text tokens (selected 1))
|
177 |
+
|
178 |
+
# Different from ca_loss without ref, which has one loss for all boxes for a phrase (a set of object positions), we have one loss per box.
|
179 |
+
|
180 |
+
# obj_ref_ca_saved_attn_items: select the timestep
|
181 |
+
obj_ref_ca_saved_attn = obj_ref_ca_saved_attn[index]
|
182 |
+
|
183 |
+
for attn_key in guidance_attn_keys:
|
184 |
+
attn_map = saved_attn[attn_key]
|
185 |
+
if not attn_map.is_cuda:
|
186 |
+
attn_map = attn_map.cuda()
|
187 |
+
attn_map = attn_map.squeeze(dim=0)
|
188 |
+
|
189 |
+
obj_ref_ca_map = obj_ref_ca_saved_attn[attn_key]
|
190 |
+
if not obj_ref_ca_map.is_cuda:
|
191 |
+
obj_ref_ca_map = obj_ref_ca_map.cuda()
|
192 |
+
# obj_ref_ca_map: (batch, heads, 2d dim, num text token)
|
193 |
+
# `squeeze` on `obj_ref_ca_map` is combined with the subsequent indexing
|
194 |
+
|
195 |
+
# b is the number of heads, not batch
|
196 |
+
b, i, j = attn_map.shape
|
197 |
+
H = W = int(math.sqrt(i))
|
198 |
+
# `obj_ref_ca_map` only has one text token (the 0 at the last dimension)
|
199 |
+
|
200 |
+
assert obj_ref_ca_map.ndim == 4, f"{obj_ref_ca_map.shape}"
|
201 |
+
obj_ref_ca_map = obj_ref_ca_map[0, :, :, 0]
|
202 |
+
|
203 |
+
# Same mask for all heads
|
204 |
+
obj_mask = torch.zeros(size=(H, W), device="cuda")
|
205 |
+
# x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
206 |
+
x_min, y_min, x_max, y_max = utils.scale_proportion(obj_box, H=H, W=W)
|
207 |
+
obj_mask[y_min: y_max, x_min: x_max] = 1
|
208 |
+
|
209 |
+
# keep 1d mask
|
210 |
+
obj_mask = obj_mask.reshape(1, -1)
|
211 |
+
|
212 |
+
# Optimize the loss over the last phrase token only (assuming the indices in `object_positions[obj_idx]` is sorted)
|
213 |
+
if ref_ca_word_token_only:
|
214 |
+
object_positions_to_iterate = [word_token_indices[obj_idx]]
|
215 |
+
elif ref_ca_last_token_only:
|
216 |
+
object_positions_to_iterate = [object_positions[obj_idx][-1]]
|
217 |
+
else:
|
218 |
+
print(f"Applying attention transfer from one attention to all attention maps in object positions {object_positions[obj_idx]}, which is likely to be incorrect")
|
219 |
+
object_positions_to_iterate = object_positions[obj_idx]
|
220 |
+
for obj_position in object_positions_to_iterate:
|
221 |
+
ca_map_obj = attn_map[:, :, obj_position]
|
222 |
+
|
223 |
+
ca_map_obj_masked = ca_map_obj * obj_mask
|
224 |
+
|
225 |
+
# Add eps because the sum can be very small, causing NaN
|
226 |
+
ca_map_obj_masked_normalized = ca_map_obj_masked / (ca_map_obj_masked.sum(dim=-1, keepdim=True) + eps)
|
227 |
+
obj_ref_ca_map_masked = obj_ref_ca_map * obj_mask
|
228 |
+
obj_ref_ca_map_masked_normalized = obj_ref_ca_map_masked / (obj_ref_ca_map_masked.sum(dim=-1, keepdim=True) + eps)
|
229 |
+
|
230 |
+
# We found dividing by object mask size makes the loss too small. Since the normalized masked attn has mean value inversely proportional to the mask size, summing the values up spatially gives a relatively standard scale to add to other losses.
|
231 |
+
activation_value = (torch.abs(ca_map_obj_masked_normalized - obj_ref_ca_map_masked_normalized)).sum(dim=-1)
|
232 |
+
|
233 |
+
obj_loss += torch.mean(activation_value, dim=0)
|
234 |
+
|
235 |
+
# The normalization for len(obj_ref_ca_map_items) is at the outside of this function.
|
236 |
+
# Note that we assume we have at least one box for each object
|
237 |
+
loss += loss_weight * obj_loss / (len(obj_boxes) * len(object_positions_to_iterate))
|
238 |
+
|
239 |
+
if verbose:
|
240 |
+
print(f"reference cross-attention obj_loss: unweighted {obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}, weighted {loss_weight * obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}")
|
241 |
+
|
242 |
+
return loss
|
243 |
+
|
244 |
+
def compute_ca_lossv3(saved_attn, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns=None, ref_ca_last_token_only=True, ref_ca_word_token_only=False, word_token_indices=None, index=None, ref_ca_loss_weight=1.0, verbose=False, **kwargs):
|
245 |
+
"""
|
246 |
+
v3 is equivalent to v2 but with new dictionary format for attention maps.
|
247 |
+
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
|
248 |
+
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
|
249 |
+
|
250 |
+
`index` is the timestep.
|
251 |
+
`ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
|
252 |
+
`ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
|
253 |
+
"""
|
254 |
+
loss = torch.tensor(0).float().cuda()
|
255 |
+
object_number = len(bboxes)
|
256 |
+
if object_number == 0:
|
257 |
+
return loss
|
258 |
+
|
259 |
+
for attn_key in guidance_attn_keys:
|
260 |
+
# We only have 1 cross attention for mid.
|
261 |
+
attn_map_integrated = saved_attn[attn_key]
|
262 |
+
if not attn_map_integrated.is_cuda:
|
263 |
+
attn_map_integrated = attn_map_integrated.cuda()
|
264 |
+
# Example dimension: [20, 64, 77]
|
265 |
+
attn_map = attn_map_integrated.squeeze(dim=0)
|
266 |
+
loss = add_ca_loss_per_attn_map_to_loss(loss, attn_map, object_number, bboxes, object_positions, verbose=verbose, **kwargs)
|
267 |
+
|
268 |
+
num_attn = len(guidance_attn_keys)
|
269 |
+
|
270 |
+
if num_attn > 0:
|
271 |
+
loss = loss / (object_number * num_attn)
|
272 |
+
|
273 |
+
if ref_ca_saved_attns is not None:
|
274 |
+
ref_loss = torch.tensor(0).float().cuda()
|
275 |
+
ref_loss = add_ref_ca_loss_per_attn_map_to_lossv2(
|
276 |
+
ref_loss, saved_attn=saved_attn, object_number=object_number, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
|
277 |
+
ref_ca_saved_attns=ref_ca_saved_attns, ref_ca_last_token_only=ref_ca_last_token_only, ref_ca_word_token_only=ref_ca_word_token_only, word_token_indices=word_token_indices, verbose=verbose, index=index, loss_weight=ref_ca_loss_weight
|
278 |
+
)
|
279 |
+
|
280 |
+
num_attn = len(guidance_attn_keys)
|
281 |
+
|
282 |
+
if verbose:
|
283 |
+
print(f"loss {loss.item():.3f}, reference attention loss (weighted) {ref_loss.item() / (object_number * num_attn):.3f}")
|
284 |
+
|
285 |
+
loss += ref_loss / (object_number * num_attn)
|
286 |
+
|
287 |
+
return loss
|
288 |
+
|
289 |
+
# For compatibility
|
290 |
+
def add_ref_ca_loss_per_attn_map_to_loss(loss, attn_maps, object_number, bboxes, object_positions, ref_ca_maps, stage_id, index, verbose=False):
|
291 |
+
"""
|
292 |
+
This adds the ca loss with ref. Note that this should be used with ca loss without ref since it only enforces the mse of the normalized ca between ref and target.
|
293 |
+
|
294 |
+
ref_ca_maps should have the same structure as bboxes and object_positions.
|
295 |
+
"""
|
296 |
+
# attn_map_items is all cond ca maps for current down/mid/up for the overall generation.
|
297 |
+
attn_map_items = attn_maps[stage_id]
|
298 |
+
|
299 |
+
for obj_idx in range(object_number):
|
300 |
+
obj_loss = 0
|
301 |
+
|
302 |
+
obj_boxes = bboxes[obj_idx]
|
303 |
+
obj_ref_ca_maps = ref_ca_maps[obj_idx]
|
304 |
+
|
305 |
+
# We support two level (one box per phrase) and three level (multiple boxes per phrase)
|
306 |
+
if not isinstance(obj_boxes[0], Iterable):
|
307 |
+
obj_boxes = [obj_boxes]
|
308 |
+
obj_ref_ca_maps = [obj_ref_ca_maps]
|
309 |
+
|
310 |
+
assert len(obj_boxes) == len(obj_ref_ca_maps), f"obj_boxes: {len(obj_boxes)}, obj_ref_ca_maps: {len(obj_ref_ca_maps)}"
|
311 |
+
|
312 |
+
for obj_box, obj_ref_ca_map_items in zip(obj_boxes, obj_ref_ca_maps):
|
313 |
+
# obj_ref_ca_map_items format: (stage, timestep (index), block, batch, heads, 2d dim, num text tokens (selected 1))
|
314 |
+
# Different from ca_loss without ref, which has one loss for all boxes for a phrase (a set of object positions), we have one loss per box.
|
315 |
+
|
316 |
+
# print(len(obj_ref_ca_map_items), obj_ref_ca_map_items[stage_id].shape)
|
317 |
+
# Mid example: 1 torch.Size([50, 1, 1, 8, 64, 1])
|
318 |
+
# Up example: 3 torch.Size([50, 3, 1, 8, 256, 1])
|
319 |
+
|
320 |
+
# obj_ref_ca_map_items is all cond ca maps for current down/mid/up for the single object generation.
|
321 |
+
obj_ref_ca_map_items = obj_ref_ca_map_items[stage_id][index]
|
322 |
+
|
323 |
+
for attn_map, obj_ref_ca_map in zip(attn_map_items, obj_ref_ca_map_items):
|
324 |
+
attn_map = attn_map.squeeze(dim=0)
|
325 |
+
# b is the number of heads, not batch
|
326 |
+
b, i, j = attn_map.shape
|
327 |
+
H = W = int(math.sqrt(i))
|
328 |
+
# obj_ref_ca_map only has one text token (the 0 at the last dimension)
|
329 |
+
|
330 |
+
assert obj_ref_ca_map.ndim == 4, f"{obj_ref_ca_map.ndim}"
|
331 |
+
obj_ref_ca_map = obj_ref_ca_map[0, :, :, 0]
|
332 |
+
|
333 |
+
# Same mask for all heads
|
334 |
+
obj_mask = torch.zeros(size=(H, W), device="cuda")
|
335 |
+
x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
|
336 |
+
int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
337 |
+
obj_mask[y_min: y_max, x_min: x_max] = 1
|
338 |
+
|
339 |
+
# keep 1d mask
|
340 |
+
obj_mask = obj_mask.reshape(1, -1)
|
341 |
+
|
342 |
+
for obj_position in object_positions[obj_idx]:
|
343 |
+
ca_map_obj = attn_map[:, :, obj_position]
|
344 |
+
|
345 |
+
ca_map_obj_masked = ca_map_obj * obj_mask
|
346 |
+
obj_ref_ca_map_masked = obj_ref_ca_map * obj_mask
|
347 |
+
# We found dividing by object mask size makes the loss too small. Since the normalized masked attn has mean value inversely proportional to the mask size, summing the values up spatially gives a relatively standard scale to add to other losses.
|
348 |
+
activation_value = (torch.abs(ca_map_obj_masked / ca_map_obj_masked.sum(dim=-1, keepdim=True) - obj_ref_ca_map_masked / obj_ref_ca_map_masked.sum(dim=-1, keepdim=True))).sum(dim=-1) # / obj_mask.sum()
|
349 |
+
|
350 |
+
obj_loss += torch.mean(activation_value, dim=0)
|
351 |
+
|
352 |
+
# The normalization for len(obj_ref_ca_map_items) is at the outside of this function.
|
353 |
+
loss += obj_loss / (len(obj_boxes) * len(object_positions[obj_idx]))
|
354 |
+
|
355 |
+
if verbose:
|
356 |
+
print(f"reference cross-attention obj_loss: {obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}")
|
357 |
+
|
358 |
+
return loss
|
utils/latents.py
CHANGED
@@ -44,9 +44,10 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
|
|
44 |
|
45 |
# Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
|
46 |
if use_fast_schedule:
|
47 |
-
# If we use fast schedule, we only
|
48 |
composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape), dtype=dtype)
|
49 |
else:
|
|
|
50 |
composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
|
51 |
composed_latents[0] = latents_bg
|
52 |
|
@@ -73,7 +74,7 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
|
|
73 |
latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
|
74 |
foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
|
75 |
mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
|
76 |
-
composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
|
77 |
|
78 |
composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
|
79 |
return composed_latents, foreground_indices
|
|
|
44 |
|
45 |
# Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
|
46 |
if use_fast_schedule:
|
47 |
+
# If we use fast schedule, we only compose the frozen steps because the later steps do not match.
|
48 |
composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape), dtype=dtype)
|
49 |
else:
|
50 |
+
# Otherwise we compose all steps so that we don't need to compose again if we change the frozen steps.
|
51 |
composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
|
52 |
composed_latents[0] = latents_bg
|
53 |
|
|
|
74 |
latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
|
75 |
foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
|
76 |
mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
|
77 |
+
composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all[:fast_after_steps + 1] * mask_tensor_expanded
|
78 |
|
79 |
composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
|
80 |
return composed_latents, foreground_indices
|
utils/parse.py
CHANGED
@@ -1,33 +1,39 @@
|
|
1 |
import ast
|
2 |
-
import os
|
3 |
-
import json
|
4 |
from matplotlib.patches import Polygon
|
5 |
from matplotlib.collections import PatchCollection
|
6 |
import matplotlib.pyplot as plt
|
7 |
import numpy as np
|
8 |
-
import
|
9 |
import inflect
|
10 |
|
11 |
p = inflect.engine()
|
12 |
|
13 |
img_dir = "imgs"
|
|
|
14 |
bg_prompt_text = "Background prompt: "
|
|
|
|
|
|
|
|
|
15 |
# h, w
|
16 |
box_scale = (512, 512)
|
17 |
size = box_scale
|
18 |
size_h, size_w = size
|
19 |
print(f"Using box scale: {box_scale}")
|
20 |
|
|
|
21 |
def parse_input(text=None, no_input=False):
|
|
|
|
|
22 |
if not text:
|
23 |
if no_input:
|
24 |
return
|
25 |
|
26 |
text = input("Enter the response: ")
|
27 |
-
if
|
28 |
-
text = text.split(
|
29 |
|
30 |
-
text_split = text.split(
|
31 |
if len(text_split) == 2:
|
32 |
gen_boxes, bg_prompt = text_split
|
33 |
elif len(text_split) == 1:
|
@@ -38,8 +44,8 @@ def parse_input(text=None, no_input=False):
|
|
38 |
while not bg_prompt:
|
39 |
# Ignore the empty lines in the response
|
40 |
bg_prompt = input("Enter the background prompt: ").strip()
|
41 |
-
if
|
42 |
-
bg_prompt = bg_prompt.split(
|
43 |
else:
|
44 |
raise ValueError(f"text: {text}")
|
45 |
try:
|
@@ -54,7 +60,70 @@ def parse_input(text=None, no_input=False):
|
|
54 |
|
55 |
return gen_boxes, bg_prompt
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
|
|
|
|
|
|
|
58 |
if len(gen_boxes) == 0:
|
59 |
return []
|
60 |
|
@@ -62,9 +131,13 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
|
|
62 |
gen_boxes_new = []
|
63 |
for gen_box in gen_boxes:
|
64 |
if isinstance(gen_box, dict):
|
|
|
|
|
65 |
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
|
66 |
box_dict_format = True
|
67 |
else:
|
|
|
|
|
68 |
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
|
69 |
if bbox_w <= 0 or bbox_h <= 0:
|
70 |
# Empty boxes
|
@@ -73,6 +146,12 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
|
|
73 |
if (bbox_w >= size[1] and bbox_h >= size[0]) or bbox_x > size[1] or bbox_y > size[0]:
|
74 |
# Ignore the background boxes
|
75 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
gen_boxes_new.append(gen_box)
|
77 |
|
78 |
gen_boxes = gen_boxes_new
|
@@ -99,9 +178,11 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
|
|
99 |
|
100 |
# Used if scale_boxes is True
|
101 |
shift = -bbox_left_x_min
|
102 |
-
|
|
|
|
|
103 |
|
104 |
-
scale = min(
|
105 |
|
106 |
for gen_box in gen_boxes:
|
107 |
if box_dict_format:
|
@@ -165,7 +246,7 @@ def draw_boxes(anns):
|
|
165 |
ax.add_collection(p)
|
166 |
|
167 |
|
168 |
-
def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
|
169 |
if len(gen_boxes) == 0:
|
170 |
return
|
171 |
|
@@ -183,7 +264,7 @@ def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
|
|
183 |
|
184 |
if bg_prompt is not None:
|
185 |
ax = plt.gca()
|
186 |
-
ax.text(0, 0, bg_prompt, style='italic',
|
187 |
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
188 |
|
189 |
c = (np.zeros((1, 3)))
|
@@ -200,12 +281,6 @@ def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
|
|
200 |
draw_boxes(anns)
|
201 |
if show:
|
202 |
plt.show()
|
203 |
-
else:
|
204 |
-
print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
|
205 |
-
if ind is not None:
|
206 |
-
plt.savefig(f"{img_dir}/boxes_{ind}.png")
|
207 |
-
plt.savefig(f"{img_dir}/boxes.png")
|
208 |
-
|
209 |
|
210 |
def show_masks(masks):
|
211 |
masks_to_show = np.zeros((*size, 3), dtype=np.float32)
|
|
|
1 |
import ast
|
|
|
|
|
2 |
from matplotlib.patches import Polygon
|
3 |
from matplotlib.collections import PatchCollection
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
+
import warnings
|
7 |
import inflect
|
8 |
|
9 |
p = inflect.engine()
|
10 |
|
11 |
img_dir = "imgs"
|
12 |
+
objects_text = "Objects: "
|
13 |
bg_prompt_text = "Background prompt: "
|
14 |
+
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
|
15 |
+
neg_prompt_text = "Negative prompt: "
|
16 |
+
neg_prompt_text_no_trailing_space = neg_prompt_text.rstrip()
|
17 |
+
|
18 |
# h, w
|
19 |
box_scale = (512, 512)
|
20 |
size = box_scale
|
21 |
size_h, size_w = size
|
22 |
print(f"Using box scale: {box_scale}")
|
23 |
|
24 |
+
|
25 |
def parse_input(text=None, no_input=False):
|
26 |
+
warnings.warn("Parsing input without negative prompt is deprecated.")
|
27 |
+
|
28 |
if not text:
|
29 |
if no_input:
|
30 |
return
|
31 |
|
32 |
text = input("Enter the response: ")
|
33 |
+
if objects_text in text:
|
34 |
+
text = text.split(objects_text)[1]
|
35 |
|
36 |
+
text_split = text.split(bg_prompt_text_no_trailing_space)
|
37 |
if len(text_split) == 2:
|
38 |
gen_boxes, bg_prompt = text_split
|
39 |
elif len(text_split) == 1:
|
|
|
44 |
while not bg_prompt:
|
45 |
# Ignore the empty lines in the response
|
46 |
bg_prompt = input("Enter the background prompt: ").strip()
|
47 |
+
if bg_prompt_text_no_trailing_space in bg_prompt:
|
48 |
+
bg_prompt = bg_prompt.split(bg_prompt_text_no_trailing_space)[1]
|
49 |
else:
|
50 |
raise ValueError(f"text: {text}")
|
51 |
try:
|
|
|
60 |
|
61 |
return gen_boxes, bg_prompt
|
62 |
|
63 |
+
def parse_input_with_negative(text=None, no_input=False):
|
64 |
+
# no_input: should not request interactive input
|
65 |
+
|
66 |
+
if not text:
|
67 |
+
if no_input:
|
68 |
+
return
|
69 |
+
|
70 |
+
text = input("Enter the response: ")
|
71 |
+
if objects_text in text:
|
72 |
+
text = text.split(objects_text)[1]
|
73 |
+
|
74 |
+
text_split = text.split(bg_prompt_text_no_trailing_space)
|
75 |
+
if len(text_split) == 2:
|
76 |
+
gen_boxes, text_rem = text_split
|
77 |
+
elif len(text_split) == 1:
|
78 |
+
if no_input:
|
79 |
+
return
|
80 |
+
gen_boxes = text
|
81 |
+
text_rem = ""
|
82 |
+
while not text_rem:
|
83 |
+
# Ignore the empty lines in the response
|
84 |
+
text_rem = input("Enter the background prompt: ").strip()
|
85 |
+
if bg_prompt_text_no_trailing_space in text_rem:
|
86 |
+
text_rem = text_rem.split(bg_prompt_text_no_trailing_space)[1]
|
87 |
+
else:
|
88 |
+
raise ValueError(f"text: {text}")
|
89 |
+
|
90 |
+
text_split = text_rem.split(neg_prompt_text_no_trailing_space)
|
91 |
+
|
92 |
+
if len(text_split) == 2:
|
93 |
+
bg_prompt, neg_prompt = text_split
|
94 |
+
elif len(text_split) == 1:
|
95 |
+
bg_prompt = text_rem
|
96 |
+
# Negative prompt is optional: if it's not provided, we default to empty string
|
97 |
+
neg_prompt = ""
|
98 |
+
if not no_input:
|
99 |
+
# Ignore the empty lines in the response
|
100 |
+
neg_prompt = input("Enter the negative prompt: ").strip()
|
101 |
+
if neg_prompt_text_no_trailing_space in neg_prompt:
|
102 |
+
neg_prompt = neg_prompt.split(neg_prompt_text_no_trailing_space)[1]
|
103 |
+
else:
|
104 |
+
raise ValueError(f"text: {text}")
|
105 |
+
|
106 |
+
try:
|
107 |
+
gen_boxes = ast.literal_eval(gen_boxes)
|
108 |
+
except SyntaxError as e:
|
109 |
+
# Sometimes the response is in plain text
|
110 |
+
if "No objects" in gen_boxes or gen_boxes.strip() == "":
|
111 |
+
gen_boxes = []
|
112 |
+
else:
|
113 |
+
raise e
|
114 |
+
bg_prompt = bg_prompt.strip()
|
115 |
+
neg_prompt = neg_prompt.strip()
|
116 |
+
|
117 |
+
# LLM may return "None" to mean no negative prompt provided.
|
118 |
+
if neg_prompt == "None":
|
119 |
+
neg_prompt = ""
|
120 |
+
|
121 |
+
return gen_boxes, bg_prompt, neg_prompt
|
122 |
+
|
123 |
def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
|
124 |
+
if gen_boxes is None:
|
125 |
+
return []
|
126 |
+
|
127 |
if len(gen_boxes) == 0:
|
128 |
return []
|
129 |
|
|
|
131 |
gen_boxes_new = []
|
132 |
for gen_box in gen_boxes:
|
133 |
if isinstance(gen_box, dict):
|
134 |
+
if not gen_box['bounding_box']:
|
135 |
+
continue
|
136 |
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
|
137 |
box_dict_format = True
|
138 |
else:
|
139 |
+
if not gen_box[1]:
|
140 |
+
continue
|
141 |
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
|
142 |
if bbox_w <= 0 or bbox_h <= 0:
|
143 |
# Empty boxes
|
|
|
146 |
if (bbox_w >= size[1] and bbox_h >= size[0]) or bbox_x > size[1] or bbox_y > size[0]:
|
147 |
# Ignore the background boxes
|
148 |
continue
|
149 |
+
|
150 |
+
if bbox_x < 0 or bbox_y < 0 or bbox_x + bbox_w > size[1] or bbox_y + bbox_h > size[0]:
|
151 |
+
# Out of bounds boxes exist: we need to scale and shift all the boxes
|
152 |
+
print(f"**Some boxes are out of bounds: {gen_box}, scaling all the boxes to fit**")
|
153 |
+
scale_boxes = True
|
154 |
+
|
155 |
gen_boxes_new.append(gen_box)
|
156 |
|
157 |
gen_boxes = gen_boxes_new
|
|
|
178 |
|
179 |
# Used if scale_boxes is True
|
180 |
shift = -bbox_left_x_min
|
181 |
+
# Make sure the boxes fit horizontally and vertically
|
182 |
+
scale_w = size_w / (bbox_right_x_max - bbox_left_x_min)
|
183 |
+
scale_h = size_h / (bbox_bottom_y_max - bbox_top_y_min)
|
184 |
|
185 |
+
scale = min(scale_w, scale_h, max_scale)
|
186 |
|
187 |
for gen_box in gen_boxes:
|
188 |
if box_dict_format:
|
|
|
246 |
ax.add_collection(p)
|
247 |
|
248 |
|
249 |
+
def show_boxes(gen_boxes, bg_prompt=None, neg_prompt=None, ind=None, show=False):
|
250 |
if len(gen_boxes) == 0:
|
251 |
return
|
252 |
|
|
|
264 |
|
265 |
if bg_prompt is not None:
|
266 |
ax = plt.gca()
|
267 |
+
ax.text(0, 0, bg_prompt + f"(Neg: {neg_prompt})" if neg_prompt else bg_prompt, style='italic',
|
268 |
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
269 |
|
270 |
c = (np.zeros((1, 3)))
|
|
|
281 |
draw_boxes(anns)
|
282 |
if show:
|
283 |
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
|
285 |
def show_masks(masks):
|
286 |
masks_to_show = np.zeros((*size, 3), dtype=np.float32)
|
utils/utils.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import torch
|
2 |
from PIL import ImageDraw
|
3 |
import numpy as np
|
4 |
-
import os
|
5 |
import gc
|
6 |
|
7 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
1 |
import torch
|
2 |
from PIL import ImageDraw
|
3 |
import numpy as np
|
|
|
4 |
import gc
|
5 |
|
6 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
utils/vis.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import math
|
3 |
+
import utils
|
4 |
+
from . import parse
|
5 |
+
|
6 |
+
save_ind = 0
|
7 |
+
|
8 |
+
def visualize(image, title, colorbar=False, show_plot=True, **kwargs):
|
9 |
+
plt.title(title)
|
10 |
+
plt.imshow(image, **kwargs)
|
11 |
+
if colorbar:
|
12 |
+
plt.colorbar()
|
13 |
+
if show_plot:
|
14 |
+
plt.show()
|
15 |
+
|
16 |
+
def visualize_arrays(image_title_pairs, colorbar_index=-1, show_plot=True, figsize=None, **kwargs):
|
17 |
+
if figsize is not None:
|
18 |
+
plt.figure(figsize=figsize)
|
19 |
+
num_subplots = len(image_title_pairs)
|
20 |
+
for idx, image_title_pair in enumerate(image_title_pairs):
|
21 |
+
plt.subplot(1, num_subplots, idx+1)
|
22 |
+
if isinstance(image_title_pair, (list, tuple)):
|
23 |
+
image, title = image_title_pair
|
24 |
+
else:
|
25 |
+
image, title = image_title_pair, None
|
26 |
+
|
27 |
+
if title is not None:
|
28 |
+
plt.title(title)
|
29 |
+
|
30 |
+
plt.imshow(image, **kwargs)
|
31 |
+
if idx == colorbar_index:
|
32 |
+
plt.colorbar()
|
33 |
+
|
34 |
+
if show_plot:
|
35 |
+
plt.show()
|
36 |
+
|
37 |
+
def visualize_masked_latents(latents_all, masked_latents, timestep_T=False, timestep_0=True):
|
38 |
+
if timestep_T:
|
39 |
+
# from T to 0
|
40 |
+
latent_idx = 0
|
41 |
+
|
42 |
+
plt.subplot(1, 2, 1)
|
43 |
+
plt.title("latents_all (t=T)")
|
44 |
+
plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1,2,0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
|
45 |
+
|
46 |
+
plt.subplot(1, 2, 2)
|
47 |
+
plt.title("mask latents (t=T)")
|
48 |
+
plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1,2,0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
|
49 |
+
|
50 |
+
plt.show()
|
51 |
+
|
52 |
+
if timestep_0:
|
53 |
+
latent_idx = -1
|
54 |
+
plt.subplot(1, 2, 1)
|
55 |
+
plt.title("latents_all (t=0)")
|
56 |
+
plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1,2,0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
|
57 |
+
|
58 |
+
plt.subplot(1, 2, 2)
|
59 |
+
plt.title("mask latents (t=0)")
|
60 |
+
plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1,2,0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
|
61 |
+
|
62 |
+
plt.show()
|
63 |
+
|
64 |
+
# This function has not been adapted to new `saved_attn`.
|
65 |
+
def visualize_attn(token_map, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
|
66 |
+
"""
|
67 |
+
Visualize cross attention: `stage_id`th downsampling block, mean over all timesteps starting from step start, `block_id`th Transformer block, second item (conditioned), mean over heads, show each token
|
68 |
+
cross_attention_probs_tensors:
|
69 |
+
One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
|
70 |
+
stage_id: index of downsampling/mid/upsaming block
|
71 |
+
block_id: index of the transformer block
|
72 |
+
"""
|
73 |
+
|
74 |
+
plt.figure(figsize=(20, 8))
|
75 |
+
|
76 |
+
for token_id in range(len(token_map)):
|
77 |
+
token = token_map[token_id]
|
78 |
+
plt.subplot(1, len(token_map), token_id + 1)
|
79 |
+
plt.title(token)
|
80 |
+
attn = cross_attention_probs_tensors[stage_id][visualize_step_start:].mean(dim=0)[block_id]
|
81 |
+
|
82 |
+
if not input_ca_has_condition_only:
|
83 |
+
assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
|
84 |
+
attn = attn[1]
|
85 |
+
else:
|
86 |
+
assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
|
87 |
+
attn = attn[0]
|
88 |
+
|
89 |
+
attn = attn.mean(dim=0)[:, token_id]
|
90 |
+
H = W = int(math.sqrt(attn.shape[0]))
|
91 |
+
attn = attn.reshape((H, W))
|
92 |
+
plt.imshow(attn.cpu().numpy())
|
93 |
+
|
94 |
+
plt.show()
|
95 |
+
|
96 |
+
# This function has not been adapted to new `saved_attn`.
|
97 |
+
def visualize_across_timesteps(token_id, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
|
98 |
+
"""
|
99 |
+
Visualize cross attention for one token, across timesteps: `stage_id`th downsampling block, mean over all timesteps starting from step start, `block_id`th Transformer block, second item (conditioned), mean over heads, show each token
|
100 |
+
cross_attention_probs_tensors:
|
101 |
+
One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
|
102 |
+
stage_id: index of downsampling/mid/upsaming block
|
103 |
+
block_id: index of the transformer block
|
104 |
+
|
105 |
+
`visualize_step_start` is not used. We visualize all timesteps.
|
106 |
+
"""
|
107 |
+
plt.figure(figsize=(50, 8))
|
108 |
+
|
109 |
+
attn_stage = cross_attention_probs_tensors[stage_id]
|
110 |
+
num_inference_steps = attn_stage.shape[0]
|
111 |
+
|
112 |
+
for t in range(num_inference_steps):
|
113 |
+
plt.subplot(1, num_inference_steps, t + 1)
|
114 |
+
plt.title(f"t: {t}")
|
115 |
+
|
116 |
+
attn = attn_stage[t][block_id]
|
117 |
+
|
118 |
+
if not input_ca_has_condition_only:
|
119 |
+
assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
|
120 |
+
attn = attn[1]
|
121 |
+
else:
|
122 |
+
assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
|
123 |
+
attn = attn[0]
|
124 |
+
|
125 |
+
attn = attn.mean(dim=0)[:, token_id]
|
126 |
+
H = W = int(math.sqrt(attn.shape[0]))
|
127 |
+
attn = attn.reshape((H, W))
|
128 |
+
plt.imshow(attn.cpu().numpy())
|
129 |
+
plt.axis("off")
|
130 |
+
plt.tight_layout()
|
131 |
+
|
132 |
+
plt.show()
|
133 |
+
|
134 |
+
def visualize_bboxes(bboxes, H, W):
|
135 |
+
num_boxes = len(bboxes)
|
136 |
+
for ind, bbox in enumerate(bboxes):
|
137 |
+
plt.subplot(1, num_boxes, ind + 1)
|
138 |
+
fg_mask = utils.proportion_to_mask(bbox, H, W)
|
139 |
+
plt.title(f"transformed bbox ({ind})")
|
140 |
+
plt.imshow(fg_mask.cpu().numpy())
|
141 |
+
plt.show()
|
142 |
+
|
143 |
+
def display(image, save_prefix="", ind=None):
|
144 |
+
global save_ind
|
145 |
+
if save_prefix != "":
|
146 |
+
save_prefix = save_prefix + "_"
|
147 |
+
ind = f"{ind}_" if ind is not None else ""
|
148 |
+
path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png"
|
149 |
+
|
150 |
+
print(f"Saved to {path}")
|
151 |
+
|
152 |
+
image.save(path)
|
153 |
+
save_ind = save_ind + 1
|