Tony Lian committed on
Commit
89f6983
1 Parent(s): e32648c

Update: add attention guidance and refactor the code

Files changed (13)
  1. app.py +77 -149
  2. examples.py +56 -6
  3. generation.py +412 -130
  4. models/modeling_utils.py +0 -874
  5. models/pipelines.py +352 -2
  6. models/sam.py +4 -2
  7. utils/attn.py +140 -0
  8. utils/boxdiff.py +259 -0
  9. utils/guidance.py +358 -0
  10. utils/latents.py +3 -2
  11. utils/parse.py +93 -18
  12. utils/utils.py +0 -1
  13. utils/vis.py +153 -0
app.py CHANGED
@@ -1,65 +1,27 @@
1
  import gradio as gr
2
  import numpy as np
3
- import ast
4
- from matplotlib.patches import Polygon
5
- from matplotlib.collections import PatchCollection
6
  import matplotlib.pyplot as plt
7
- from utils.parse import filter_boxes
8
  from generation import run as run_ours
9
  from baseline import run as run_baseline
10
  import torch
11
  from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
12
- from examples import stage1_examples, stage2_examples
13
 
14
- print(f"Is CUDA available: {torch.cuda.is_available()}")
15
- if torch.cuda.is_available():
16
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
-
18
- box_scale = (512, 512)
19
- size = box_scale
20
-
21
- bg_prompt_text = "Background prompt: "
22
-
23
- default_template = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Make the boxes larger if possible. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Generate the object descriptions and background prompts in English even if the caption might not be in English. Do not include non-existing or excluded objects in the background prompt. Please refer to the example below for the desired format.
24
-
25
- Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
26
- Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
27
- Background prompt: A realistic image of a landscape scene
28
-
29
- Caption: A watercolor painting of a wooden table in the living room with an apple on it
30
- Objects: [('a wooden table', [65, 243, 344, 206]), ('a apple', [206, 306, 81, 69])]
31
- Background prompt: A watercolor painting of a living room
32
-
33
- Caption: A watercolor painting of two pandas eating bamboo in a forest
34
- Objects: [('a panda eating bambooo', [30, 171, 212, 226]), ('a panda eating bambooo', [264, 173, 222, 221])]
35
- Background prompt: A watercolor painting of a forest
36
-
37
- Caption: A realistic image of four skiers standing in a line on the snow near a palm tree
38
- Objects: [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 180, 103, 180])]
39
- Background prompt: A realistic image of an outdoor scene with snow
40
 
41
- Caption: An oil painting of a pink dolphin jumping on the left of a steam boat on the sea
42
- Objects: [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]
43
- Background prompt: An oil painting of the sea
44
-
45
- Caption: A realistic image of a cat playing with a dog in a park with flowers
46
- Objects: [('a playful cat', [51, 67, 271, 324]), ('a playful dog', [302, 119, 211, 228])]
47
- Background prompt: A realistic image of a park with flowers
48
-
49
- Caption: 一个客厅场景的油画,墙上挂着电视,电视下面是一个柜子,柜子上有一个花瓶。
50
- Objects: [('a tv', [88, 85, 335, 203]), ('a cabinet', [57, 308, 404, 201]), ('a flower vase', [166, 222, 92, 108])]
51
- Background prompt: An oil painting of a living room scene"""
52
-
53
- simplified_prompt = """{template}
54
-
55
- Caption: {prompt}
56
- Objects: """
57
 
58
- prompt_placeholder = "A realistic photo of a gray cat and an orange dog on the grass."
59
 
60
- layout_placeholder = """Caption: A realistic photo of a gray cat and an orange dog on the grass.
61
- Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
62
- Background prompt: A realistic photo of a grassy area."""
63
 
64
  def get_lmd_prompt(prompt, template=default_template):
65
  if prompt == "":
@@ -71,10 +33,10 @@ def get_lmd_prompt(prompt, template=default_template):
71
  def get_layout_image(response):
72
  if response == "":
73
  response = layout_placeholder
74
- gen_boxes, bg_prompt = parse_input(response)
75
  fig = plt.figure(figsize=(8, 8))
76
  # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
77
- show_boxes(gen_boxes, bg_prompt)
78
  # If we haven't already shown or saved the plot, then we need to
79
  # draw the figure first...
80
  fig.canvas.draw()
@@ -88,32 +50,41 @@ def get_layout_image(response):
88
  def get_layout_image_gallery(response):
89
  return [get_layout_image(response)]
90
 
91
- def get_ours_image(response, overall_prompt_override="", seed=0, num_inference_steps=20, dpm_scheduler=True, use_autocast=False, fg_seed_start=20, fg_blending_ratio=0.1, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta=0.3, so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT, show_so_imgs=False, scale_boxes=False):
92
  if response == "":
93
  response = layout_placeholder
94
- gen_boxes, bg_prompt = parse_input(response)
95
  gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
96
  spec = {
97
  # prompt is unused
98
  'prompt': '',
99
  'gen_boxes': gen_boxes,
100
- 'bg_prompt': bg_prompt
 
101
  }
102
 
103
  if dpm_scheduler:
104
  scheduler_key = "dpm_scheduler"
105
  else:
106
  scheduler_key = "scheduler"
107
-
 
 
108
  image_np, so_img_list = run_ours(
109
  spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
110
  fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
111
- gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
112
- so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt, so_batch_size=2
 
113
  )
114
  images = [image_np]
115
  if show_so_imgs:
116
  images.extend([np.asarray(so_img) for so_img in so_img_list])
117
  return images
118
 
119
  def get_baseline_image(prompt, seed=0):
@@ -126,73 +97,6 @@ def get_baseline_image(prompt, seed=0):
126
  image_np = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
127
  return [image_np]
128
 
129
- def parse_input(text=None):
130
- try:
131
- if "Objects: " in text:
132
- text = text.split("Objects: ")[1]
133
-
134
- text_split = text.split(bg_prompt_text)
135
- if len(text_split) == 2:
136
- gen_boxes, bg_prompt = text_split
137
- gen_boxes = ast.literal_eval(gen_boxes)
138
- bg_prompt = bg_prompt.strip()
139
- except Exception as e:
140
- raise gr.Error(f"response format invalid: {e} (text: {text})")
141
-
142
- return gen_boxes, bg_prompt
143
-
144
- def draw_boxes(anns):
145
- ax = plt.gca()
146
- ax.set_autoscale_on(False)
147
- polygons = []
148
- color = []
149
- for ann in anns:
150
- c = (np.random.random((1, 3))*0.6+0.4)
151
- [bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
152
- poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
153
- [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
154
- np_poly = np.array(poly).reshape((4, 2))
155
- polygons.append(Polygon(np_poly))
156
- color.append(c)
157
-
158
- # print(ann)
159
- name = ann['name'] if 'name' in ann else str(ann['category_id'])
160
- ax.text(bbox_x, bbox_y, name, style='italic',
161
- bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
162
-
163
- p = PatchCollection(polygons, facecolor='none',
164
- edgecolors=color, linewidths=2)
165
- ax.add_collection(p)
166
-
167
-
168
- def show_boxes(gen_boxes, bg_prompt=None):
169
- anns = [{'name': gen_box[0], 'bbox': gen_box[1]}
170
- for gen_box in gen_boxes]
171
-
172
- # White background (to allow line to show on the edge)
173
- I = np.ones((size[0]+4, size[1]+4, 3), dtype=np.uint8) * 255
174
-
175
- plt.imshow(I)
176
- plt.axis('off')
177
-
178
- if bg_prompt is not None:
179
- ax = plt.gca()
180
- ax.text(0, 0, bg_prompt, style='italic',
181
- bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
182
-
183
- c = np.zeros((1, 3))
184
- [bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0])
185
- poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
186
- [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
187
- np_poly = np.array(poly).reshape((4, 2))
188
- polygons = [Polygon(np_poly)]
189
- color = [c]
190
- p = PatchCollection(polygons, facecolor='none',
191
- edgecolors=color, linewidths=2)
192
- ax.add_collection(p)
193
-
194
- draw_boxes(anns)
195
-
196
  duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>'
197
 
198
  html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
@@ -200,15 +104,28 @@ html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to
200
  <h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
201
  <p><b>Tips:</b><p>
202
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
203
- <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
204
- <p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
205
  <p>4. The diffusion model only runs 20 steps by default in this demo. You can make it run more steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
206
  <p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (<b>Currently we are using a T4 GPU on this space, which is quite slow, and you can add a A10G to make it 5x faster</b>) {duplicate_html}</p>
207
  <br/>
208
- <p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>
209
- <style>.btn {{flex-grow: unset !important;}} </style>
210
  """
211
 
212
  with gr.Blocks(
213
  title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
214
  ) as g:
@@ -230,42 +147,53 @@ with gr.Blocks(
230
  inputs=[prompt],
231
  outputs=[output],
232
  fn=get_lmd_prompt,
233
- cache_examples=True
 
234
  )
235
 
236
  with gr.Tab("Stage 2 (New). Layout to Image generation"):
237
  with gr.Row():
238
  with gr.Column(scale=1):
239
- response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
240
- overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
241
- num_inference_steps = gr.Slider(1, 250, value=50, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
 
 
242
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
243
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
244
- frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
245
- gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
246
- dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
247
- use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)", show_label=False, value=True)
248
- fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
249
- fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
250
- so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
251
- overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
252
- show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False, value=False)
253
- scale_boxes = gr.Checkbox(label="Scale bounding boxes to just fit the scene", show_label=False, value=False)
254
  visualize_btn = gr.Button("Visualize Layout", elem_classes="btn")
255
  generate_btn = gr.Button("Generate Image from Layout", variant='primary', elem_classes="btn")
256
  with gr.Column(scale=1):
257
  gallery = gr.Gallery(
258
  label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain", preview=True
259
  )
 
 
260
  visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
261
- generate_btn.click(fn=get_ours_image, inputs=[response, overall_prompt_override, seed, num_inference_steps, dpm_scheduler, use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")
262
 
263
  gr.Examples(
264
  examples=stage2_examples,
265
  inputs=[response, overall_prompt_override, seed],
266
  outputs=[gallery],
267
  fn=get_ours_image,
268
- cache_examples=True
 
269
  )
270
 
271
  with gr.Tab("Baseline: Stable Diffusion"):
@@ -274,8 +202,7 @@ with gr.Blocks(
274
  sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
275
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
276
  generate_btn = gr.Button("Generate", elem_classes="btn")
277
- # with gr.Column(scale=1):
278
- # output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
279
  with gr.Column(scale=1):
280
  gallery = gr.Gallery(
281
  label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
@@ -287,7 +214,8 @@ with gr.Blocks(
287
  inputs=[sd_prompt],
288
  outputs=[gallery],
289
  fn=get_baseline_image,
290
- cache_examples=True
 
291
  )
292
 
293
  g.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import os
 
 
4
  import matplotlib.pyplot as plt
5
+ from utils.parse import filter_boxes, parse_input_with_negative, show_boxes
6
  from generation import run as run_ours
7
  from baseline import run as run_baseline
8
  import torch
9
  from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
10
+ from examples import stage1_examples, stage2_examples, default_template, simplified_prompt, prompt_placeholder, layout_placeholder
11
 
12
+ cuda_available = torch.cuda.is_available()
13
 
14
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
15
 
16
+ if cuda_available:
17
+ gpu_memory = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
18
+ low_memory = gpu_memory <= 16 * 1024 ** 3
19
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}. With GPU memory: {gpu_memory}. Low memory: {low_memory}")
20
+ else:
21
+ low_memory = False
22
 
23
+ cache_examples = True
24
+ default_num_inference_steps = 20 if low_memory else 50
 
25
 
26
  def get_lmd_prompt(prompt, template=default_template):
27
  if prompt == "":
 
33
  def get_layout_image(response):
34
  if response == "":
35
  response = layout_placeholder
36
+ gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)
37
  fig = plt.figure(figsize=(8, 8))
38
  # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
39
+ show_boxes(gen_boxes, bg_prompt, neg_prompt)
40
  # If we haven't already shown or saved the plot, then we need to
41
  # draw the figure first...
42
  fig.canvas.draw()
 
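For reference, the pasted layout follows the "Objects: [...] / Background prompt: ... / Negative prompt: ..." format from the updated template. A minimal stand-alone sketch of parsing it (the real logic lives in utils/parse.py's parse_input_with_negative; this hypothetical parse_layout helper is an assumption modeled on the parse_input removed from app.py in this commit):

import ast

def parse_layout(text):
    # Keep only what follows "Objects: ", then split off the background prompt
    # and the optional negative prompt line.
    objects_part = text.split("Objects: ")[1]
    boxes_part, rest = objects_part.split("Background prompt: ")
    gen_boxes = ast.literal_eval(boxes_part.strip())  # e.g., [('a gray cat', [67, 243, 120, 126]), ...]
    if "Negative prompt:" in rest:
        bg_prompt, neg_prompt = rest.split("Negative prompt:")
    else:
        bg_prompt, neg_prompt = rest, ""
    return gen_boxes, bg_prompt.strip(), neg_prompt.strip()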
50
  def get_layout_image_gallery(response):
51
  return [get_layout_image(response)]
52
 
53
+ def get_ours_image(response, overall_prompt_override="", seed=0, num_inference_steps=250, dpm_scheduler=True, use_autocast=False, fg_seed_start=20, fg_blending_ratio=0.1, frozen_step_ratio=0.5, attn_guidance_step_ratio=0.6, gligen_scheduled_sampling_beta=0.4, attn_guidance_scale=20, use_ref_ca=True, so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT, show_so_imgs=False, scale_boxes=False):
54
  if response == "":
55
  response = layout_placeholder
56
+ gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)
57
  gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
58
  spec = {
59
  # prompt is unused
60
  'prompt': '',
61
  'gen_boxes': gen_boxes,
62
+ 'bg_prompt': bg_prompt,
63
+ 'extra_neg_prompt': neg_prompt
64
  }
65
 
66
  if dpm_scheduler:
67
  scheduler_key = "dpm_scheduler"
68
  else:
69
  scheduler_key = "scheduler"
70
+
71
+ overall_max_index_step = int(attn_guidance_step_ratio * num_inference_steps)
72
+
73
  image_np, so_img_list = run_ours(
74
  spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
75
  fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
76
+ so_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, overall_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
77
+ use_ref_ca=use_ref_ca, so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt,
78
+ loss_scale=attn_guidance_scale, max_index_step=0, overall_loss_scale=attn_guidance_scale, overall_max_index_step=overall_max_index_step,
79
  )
80
  images = [image_np]
81
  if show_so_imgs:
82
  images.extend([np.asarray(so_img) for so_img in so_img_list])
83
+
84
+ if cuda_available:
85
+ print(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
86
+ torch.cuda.reset_max_memory_allocated()
87
+
88
  return images
89
 
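As a worked example of how the ratios above translate into step counts (using the non-low-memory default of 50 denoising steps; frozen_steps is computed the same way inside generation.run):

num_inference_steps = 50
frozen_step_ratio = 0.5            # UI default / "Standard" preset
attn_guidance_step_ratio = 0.6     # UI default / "Standard" preset

frozen_steps = int(num_inference_steps * frozen_step_ratio)                    # 25: steps with the per-box latents frozen
overall_max_index_step = int(attn_guidance_step_ratio * num_inference_steps)   # 30: last step that applies attention guidance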
90
  def get_baseline_image(prompt, seed=0):
 
97
  image_np = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
98
  return [image_np]
99
 
100
  duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>'
101
 
102
  html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
 
104
  <h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
105
  <p><b>Tips:</b><p>
106
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
107
+ <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the objects bigger or move the objects).</p>
108
+ <p>3. You can also try prompts in Simplified Chinese. You need to leave "prompt for overall image" empty in this case. If you want to try prompts in another language, translate the first line of last example to your language.</p>
109
  <p>4. The diffusion model only runs 20 steps by default in this demo. You can make it run more steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
110
  <p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (<b>Currently we are using a T4 GPU on this space, which is quite slow, and you can add a A10G to make it 5x faster</b>) {duplicate_html}</p>
111
  <br/>
112
+ <p>Implementation note (updated): In this demo, we provide a few modes: the standard mode implements what is described in the paper, while the faster modes disable attention guidance and/or per-box guidance to speed up generation. You can set the GLIGEN guidance steps ratio to 0 to disable GLIGEN and use only the original SD weights.</p>
113
+ <style>.btn {{flex-grow: unset !important;}} </style>
114
  """
115
 
116
+ def preset_change(preset):
117
+ # frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt
118
+ if preset == "Standard":
119
+ return gr.update(value=0.5, interactive=True), gr.update(value=0.6, interactive=True), gr.update(interactive=True), gr.update(value=True, interactive=True), gr.update(interactive=True)
120
+ elif preset == "Faster (disable attention guidance)":
121
+ return gr.update(value=0.5, interactive=True), gr.update(value=0, interactive=False), gr.update(interactive=False), gr.update(value=True, interactive=True), gr.update(interactive=True)
122
+ elif preset == "Faster (disable per-box guidance)":
123
+ return gr.update(value=0, interactive=False), gr.update(value=0.6, interactive=True), gr.update(interactive=True), gr.update(value=False, interactive=False), gr.update(interactive=False)
124
+ elif preset == "Fastest (disable both)":
125
+ return gr.update(value=0, interactive=False), gr.update(value=0, interactive=False), gr.update(interactive=False), gr.update(value=False, interactive=False), gr.update(interactive=True)
126
+ else:
127
+ raise gr.Error(f"Unknown preset {preset}")
128
+
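For reference, calling the handler directly shows what each preset does; the returned updates are applied, in order, to frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, and so_negative_prompt (see the preset.change wiring below):

updates = preset_change("Fastest (disable both)")
# updates[0]: gr.update(value=0, interactive=False)      -> frozen_step_ratio (per-box guidance off)
# updates[1]: gr.update(value=0, interactive=False)      -> attn_guidance_step_ratio (attention guidance off)
# updates[2]: gr.update(interactive=False)               -> attn_guidance_scale locked
# updates[3]: gr.update(value=False, interactive=False)  -> use_ref_ca disabled
# updates[4]: gr.update(interactive=True)                -> so_negative_prompt stays editable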
129
  with gr.Blocks(
130
  title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
131
  ) as g:
 
147
  inputs=[prompt],
148
  outputs=[output],
149
  fn=get_lmd_prompt,
150
+ cache_examples=cache_examples,
151
+ label="example_stage1"
152
  )
153
 
154
  with gr.Tab("Stage 2 (New). Layout to Image generation"):
155
  with gr.Row():
156
  with gr.Column(scale=1):
157
+ overall_prompt_override = gr.Textbox(lines=2, label="Prompt for the overall image (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
158
+ response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
159
+ num_inference_steps = gr.Slider(1, 100 if low_memory else 250, value=default_num_inference_steps, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
160
+ # Using an environment variable allows setting the default to faster/fastest on low-end GPUs.
161
+ preset = gr.Radio(label="Guidance: apply less control for faster generation", choices=["Standard", "Faster (disable attention guidance)", "Faster (disable per-box guidance)", "Fastest (disable both)"], value="Faster (disable attention guidance)" if low_memory else "Standard")
162
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
163
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
164
+ with gr.Tab("Guidance"):
165
+ frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: stronger attribute binding; lower: higher coherence)")
166
+ gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value, higher: stronger GLIGEN guidance)")
167
+ attn_guidance_step_ratio = gr.Slider(0, 1, value=0.6, step=0.01, label="Attention guidance steps ratio (higher: stronger attention guidance; lower: faster and higher coherence)")
168
+ attn_guidance_scale = gr.Slider(0, 50, value=20, step=0.5, label="Attention guidance scale: 0 means no attention guidance.")
169
+ use_ref_ca = gr.Checkbox(label="Using per-box attention to guide reference attention", show_label=False, value=True)
170
+ with gr.Tab("Generation"):
171
+ dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
172
+ use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)" + (" [enabled due to low GPU memory]" if low_memory else ""), show_label=False, value=True, interactive=not low_memory)
173
+ fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
174
+ fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
175
+ scale_boxes = gr.Checkbox(label="Scale bounding boxes to just fit the scene", show_label=False, value=False)
176
+ so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
177
+ overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
178
+ show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False, value=False)
179
  visualize_btn = gr.Button("Visualize Layout", elem_classes="btn")
180
  generate_btn = gr.Button("Generate Image from Layout", variant='primary', elem_classes="btn")
181
  with gr.Column(scale=1):
182
  gallery = gr.Gallery(
183
  label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain", preview=True
184
  )
185
+ preset.change(preset_change, [preset], [frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt])
186
+ prompt.change(None, [prompt], overall_prompt_override, _js="(x) => x")
187
  visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
188
+ generate_btn.click(fn=get_ours_image, inputs=[response, overall_prompt_override, seed, num_inference_steps, dpm_scheduler, use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio, attn_guidance_step_ratio, gligen_scheduled_sampling_beta, attn_guidance_scale, use_ref_ca, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")
189
 
190
  gr.Examples(
191
  examples=stage2_examples,
192
  inputs=[response, overall_prompt_override, seed],
193
  outputs=[gallery],
194
  fn=get_ours_image,
195
+ cache_examples=cache_examples,
196
+ label="example_ours"
197
  )
198
 
199
  with gr.Tab("Baseline: Stable Diffusion"):
 
202
  sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
203
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
204
  generate_btn = gr.Button("Generate", elem_classes="btn")
205
+
 
206
  with gr.Column(scale=1):
207
  gallery = gr.Gallery(
208
  label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
 
214
  inputs=[sd_prompt],
215
  outputs=[gallery],
216
  fn=get_baseline_image,
217
+ cache_examples=cache_examples,
218
+ label="example_sd"
219
  )
220
 
221
  g.launch()
examples.py CHANGED
@@ -1,3 +1,45 @@
1
  stage1_examples = [
2
  ["""A realistic photo of a wooden table with an apple on the left and a pear on the right."""],
3
  ["""A realistic photo of 4 TVs on a wall."""],
@@ -10,25 +52,33 @@ stage1_examples = [
10
 
11
  # Layout, seed
12
  stage2_examples = [
13
- ["""Caption: A realistic photo of a wooden table with an apple on the left and a pear on the right.
14
  Objects: [('a wooden table', [30, 30, 452, 452]), ('an apple', [52, 223, 50, 60]), ('a pear', [400, 240, 50, 60])]
15
- Background prompt: A realistic photo""", "", 0],
16
  ["""Caption: A realistic photo of 4 TVs on a wall.
17
  Objects: [('a TV', [12, 108, 120, 100]), ('a TV', [132, 112, 120, 100]), ('a TV', [252, 104, 120, 100]), ('a TV', [372, 106, 120, 100])]
18
- Background prompt: A realistic photo of a wall""", "", 0],
19
  ["""Caption: A realistic photo of a gray cat and an orange dog on the grass.
20
  Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
21
- Background prompt: A realistic photo of a grassy area.""", "", 0],
22
  ["""Caption: 一个室内场景的水彩画,一个桌子上面放着一盘水果
23
  Objects: [('a table', [81, 242, 350, 210]), ('a plate of fruits', [151, 287, 210, 117])]
24
  Background prompt: A watercolor painting of an indoor scene""", "", 1],
25
  ["""Caption: In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.
26
  Objects: [('a blue cube', [232, 116, 76, 76]), ('a red cube', [232, 212, 76, 76]), ('a vase', [100, 198, 62, 144])]
27
- Background prompt: An empty indoor scene""", "", 2],
28
  ["""Caption: A realistic photo of a wooden table without bananas in an indoor scene
29
  Objects: [('a wooden table', [75, 256, 365, 156])]
30
- Background prompt: A realistic photo of an indoor scene""", "", 3],
 
31
  ["""Caption: A realistic photo of two cars on the road.
32
  Objects: [('a car', [20, 242, 235, 185]), ('a car', [275, 246, 215, 180])]
33
  Background prompt: A realistic photo of a road.""", "A realistic photo of two cars on the road.", 4],
34
  ]
1
+ default_template = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512]. The bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object (i.e., start the object name with "a" or "an" if possible). Do not put objects that are already provided in the bounding boxes into the background prompt. Do not include non-existing or excluded objects in the background prompt. If needed, you can make reasonable guesses. Please refer to the example below for the desired format.
2
+
3
+ Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
4
+ Objects: [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
5
+ Background prompt: A realistic landscape scene
6
+ Negative prompt:
7
+
8
+ Caption: A realistic top-down view of a wooden table with two apples on it
9
+ Objects: [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]
10
+ Background prompt: A realistic top-down view
11
+ Negative prompt:
12
+
13
+ Caption: A realistic scene of three skiers standing in a line on the snow near a palm tree
14
+ Objects: [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]
15
+ Background prompt: A realistic outdoor scene with snow
16
+ Negative prompt:
17
+
18
+ Caption: An oil painting of a pink dolphin jumping on the left of a steam boat on the sea
19
+ Objects: [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]
20
+ Background prompt: An oil painting of the sea
21
+ Negative prompt:
22
+
23
+ Caption: A cute cat and an angry dog without birds
24
+ Objects: [('a cute cat', [51, 67, 271, 324]), ('an angry dog', [302, 119, 211, 228])]
25
+ Background prompt: A realistic scene
26
+ Negative prompt: birds
27
+
28
+ Caption: Two pandas in a forest without flowers
29
+ Objects: [('a panda', [30, 171, 212, 226]), ('a panda', [264, 173, 222, 221])]
30
+ Background prompt: A forest
31
+ Negative prompt: flowers
32
+
33
+ Caption: 一个客厅场景的油画,墙上挂着一幅画,电视下面是一个柜子,柜子上有一个花瓶,画里没有椅子。
34
+ Objects: [('a painting', [88, 85, 335, 203]), ('a cabinet', [57, 308, 404, 201]), ('a flower vase', [166, 222, 92, 108]), ('a flower vase', [328, 222, 92, 108])]
35
+ Background prompt: An oil painting of a living room scene
36
+ Negative prompt: chairs"""
37
+
38
+ simplified_prompt = """{template}
39
+
40
+ Caption: {prompt}
41
+ Objects: """
42
+
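A minimal sketch of how these pieces are presumably combined into the final LLM query (app.py's get_lmd_prompt does the equivalent; the helper name here is illustrative only):

def compose_lmd_prompt(caption, template=default_template):
    # The LLM is expected to reply with the "Objects:", "Background prompt:",
    # and "Negative prompt:" lines for the given caption.
    return simplified_prompt.format(template=template, prompt=caption)

query = compose_lmd_prompt("A realistic photo of a gray cat and an orange dog on the grass.")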
43
  stage1_examples = [
44
  ["""A realistic photo of a wooden table with an apple on the left and a pear on the right."""],
45
  ["""A realistic photo of 4 TVs on a wall."""],
 
52
 
53
  # Layout, seed
54
  stage2_examples = [
55
+ ["""Caption: A realistic top-down view of a wooden table with an apple on the left and a pear on the right.
56
  Objects: [('a wooden table', [30, 30, 452, 452]), ('an apple', [52, 223, 50, 60]), ('a pear', [400, 240, 50, 60])]
57
+ Background prompt: A realistic top-down view of a room""", "A realistic top-down view of a wooden table with an apple on the left and a pear on the right.", 0],
58
  ["""Caption: A realistic photo of 4 TVs on a wall.
59
  Objects: [('a TV', [12, 108, 120, 100]), ('a TV', [132, 112, 120, 100]), ('a TV', [252, 104, 120, 100]), ('a TV', [372, 106, 120, 100])]
60
+ Background prompt: A realistic photo of a wall""", "A realistic photo of 4 TVs on a wall.", 0],
61
  ["""Caption: A realistic photo of a gray cat and an orange dog on the grass.
62
  Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
63
+ Background prompt: A realistic photo of a grassy area.""", "A realistic photo of a gray cat and an orange dog on the grass.", 0],
64
  ["""Caption: 一个室内场景的水彩画,一个桌子上面放着一盘水果
65
  Objects: [('a table', [81, 242, 350, 210]), ('a plate of fruits', [151, 287, 210, 117])]
66
  Background prompt: A watercolor painting of an indoor scene""", "", 1],
67
  ["""Caption: In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.
68
  Objects: [('a blue cube', [232, 116, 76, 76]), ('a red cube', [232, 212, 76, 76]), ('a vase', [100, 198, 62, 144])]
69
+ Background prompt: An empty indoor scene""", "In an empty indoor scene, a blue cube directly above a red cube with a vase on the left of them.", 2],
70
  ["""Caption: A realistic photo of a wooden table without bananas in an indoor scene
71
  Objects: [('a wooden table', [75, 256, 365, 156])]
72
+ Background prompt: A realistic photo of an indoor scene
73
+ Negative prompt: bananas""", "A realistic photo of a wooden table without bananas in an indoor scene", 3],
74
  ["""Caption: A realistic photo of two cars on the road.
75
  Objects: [('a car', [20, 242, 235, 185]), ('a car', [275, 246, 215, 180])]
76
  Background prompt: A realistic photo of a road.""", "A realistic photo of two cars on the road.", 4],
77
  ]
78
+
79
+
80
+ prompt_placeholder = "A realistic photo of a gray cat and an orange dog on the grass."
81
+
82
+ layout_placeholder = """Caption: A realistic photo of a gray cat and an orange dog on the grass.
83
+ Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
84
+ Background prompt: A realistic photo of a grassy area."""
generation.py CHANGED
@@ -1,19 +1,24 @@
1
- version = "v3.0"
2
-
3
  import torch
4
- import numpy as np
5
  import models
6
  import utils
7
  from models import pipelines, sam
8
- from utils import parse, latents
9
- from shared import model_dict, sam_model_dict, DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
10
- import gc
 
12
  verbose = False
13
- # Accelerates per-box generation
14
- use_fast_schedule = True
15
 
16
- vae, tokenizer, text_encoder, unet, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.dtype
 
18
  model_dict.update(sam_model_dict)
19
 
@@ -21,195 +26,472 @@ model_dict.update(sam_model_dict)
21
  # Hyperparams
22
  height = 512 # default height of Stable Diffusion
23
  width = 512 # default width of Stable Diffusion
24
- H, W = height // 8, width // 8 # size of the latent
25
  guidance_scale = 7.5 # Scale for classifier-free guidance
26
 
27
  # batch size that is not 1 is not supported
28
  overall_batch_size = 1
29
 
30
  # discourage masks with confidence below
31
  discourage_mask_below_confidence = 0.85
32
 
33
  # discourage masks with iou (with coarse binarized attention mask) below
34
  discourage_mask_below_coarse_iou = 0.25
35
 
  run_ind = None
37
 
38
 
39
- def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
40
- sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
41
- verbose=False, scheduler_key=None, visualize=True, batch_size=None, **kwargs):
42
- # batch_size=None: does not limit the batch size (pass all input together)
43
-
44
- # prompts and words are not used since we don't have cross-attention control in this function
45
-
46
- input_latents = torch.cat(input_latents_list, dim=0)
47
-
48
- # We need to "unsqueeze" to tell that we have only one box and phrase in each batch item
49
- bboxes, phrases = [[item] for item in bboxes], [[item] for item in phrases]
50
-
51
- input_len = len(bboxes)
52
- assert len(bboxes) == len(phrases), f"{len(bboxes)} != {len(phrases)}"
53
-
54
- if batch_size is None:
55
- batch_size = input_len
56
-
57
- run_times = int(np.ceil(input_len / batch_size))
58
- mask_selected_list, single_object_pil_images_box_ann, latents_all = [], [], []
59
- for batch_idx in range(run_times):
60
- input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
61
- bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
62
- input_embeddings_batch = input_embeddings[0], input_embeddings[1][batch_idx * batch_size:(batch_idx + 1) * batch_size]
63
-
64
- _, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
65
- model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
66
- guidance_scale=guidance_scale, return_saved_cross_attn=False,
67
- return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key, **kwargs
68
- )
69
-
70
- gc.collect()
71
- torch.cuda.empty_cache()
72
-
73
- # `sam_refine_boxes` also calls `empty_cache` so we don't need to explicitly empty the cache again.
74
- mask_selected, _ = sam.sam_refine_boxes(sam_input_images=single_object_images_batch, boxes=bboxes_batch, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
75
-
76
- mask_selected_list.append(np.array(mask_selected)[:, 0])
77
- single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
78
- latents_all.append(latents_all_batch)
79
-
80
- single_object_pil_images_box_ann, latents_all = sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
81
-
82
- # mask_selected_list: List(batch)[List(image)[List(box)[Array of shape (64, 64)]]]
83
-
84
- mask_selected = np.concatenate(mask_selected_list, axis=0)
85
- mask_selected = mask_selected.reshape((-1, *mask_selected.shape[-2:]))
86
-
87
- assert mask_selected.shape[0] == input_latents.shape[0], f"{mask_selected.shape[0]} != {input_latents.shape[0]}"
88
-
89
- print(mask_selected.shape)
90
-
91
  mask_selected_tensor = torch.tensor(mask_selected)
92
-
93
- latents_all = latents_all.transpose(0,1)[:,:,None,...]
94
-
95
- gc.collect()
96
- torch.cuda.empty_cache()
97
-
98
- return latents_all, mask_selected_tensor, single_object_pil_images_box_ann
99
-
100
- def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
101
- latents_all_list, mask_tensor_list = [], []
102
-
  if not so_prompt_phrase_word_box_list:
104
- return latents_all_list, mask_tensor_list
105
-
106
- prompts, bboxes, phrases, words = [], [], [], []
107
 
108
- for prompt, phrase, word, box in so_prompt_phrase_word_box_list:
109
- prompts.append(prompt)
110
- bboxes.append(box)
111
- phrases.append(phrase)
112
- words.append(word)
113
-
114
- latents_all_list, mask_tensor_list, so_img_list = generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings=so_input_embeddings, verbose=verbose, **kwargs)
115
 
116
- return latents_all_list, mask_tensor_list, so_img_list
117
 
118
 
119
  # Note: need to keep the supervision, especially the box coordinates, corresponding to each other in single object and overall.
120
 
 
121
  def run(
122
- spec, bg_seed = 1, overall_prompt_override="", fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3, num_inference_steps = 20,
123
- so_center_box = False, fg_blending_ratio = 0.1, scheduler_key='dpm_scheduler', so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT, so_horizontal_center_only = True,
124
- align_with_overall_bboxes = False, horizontal_shift_only = True, use_autocast = False, so_batch_size = None
125
  ):
126
- """
127
  so_center_box: using centered box in single object generation
128
  so_horizontal_center_only: move to the center horizontally only
129
-
130
  align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
131
  horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
132
  """
133
-
134
- print("generation:", spec, bg_seed, fg_seed_start, frozen_step_ratio, gligen_scheduled_sampling_beta)
135
-
136
- frozen_step_ratio = min(max(frozen_step_ratio, 0.), 1.)
137
  frozen_steps = int(num_inference_steps * frozen_step_ratio)
138
 
139
- if True:
140
- so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes = parse.convert_spec(spec, height, width, verbose=verbose)
141
 
142
  if overall_prompt_override and overall_prompt_override.strip():
143
  overall_prompt = overall_prompt_override.strip()
144
 
145
- overall_phrases, overall_words, overall_bboxes = [item[0] for item in overall_phrases_words_bboxes], [item[1] for item in overall_phrases_words_bboxes], [item[2] for item in overall_phrases_words_bboxes]
146
 
147
  # The so box is centered but the overall boxes are not (since we need to place to the right place).
148
  if so_center_box:
149
- so_prompt_phrase_word_box_list = [(prompt, phrase, word, utils.get_centered_box(bbox, horizontal_center_only=so_horizontal_center_only)) for prompt, phrase, word, bbox in so_prompt_phrase_word_box_list]
150
  if verbose:
151
- print(f"centered so_prompt_phrase_word_box_list: {so_prompt_phrase_word_box_list}")
 
 
152
  so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
153
 
154
  sam_refine_kwargs = dict(
155
- discourage_mask_below_confidence=discourage_mask_below_confidence, discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
156
- height=height, width=width, H=H, W=W
  )
158
-
  # Note that so and overall use different negative prompts
160
 
161
  with torch.autocast("cuda", enabled=use_autocast):
162
  so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
163
  if so_prompts:
164
- so_input_embeddings = models.encode_prompts(prompts=so_prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=so_negative_prompt, one_uncond_input_only=True)
165
  else:
166
  so_input_embeddings = []
167
 
168
- overall_input_embeddings = models.encode_prompts(prompts=[overall_prompt], tokenizer=tokenizer, negative_prompt=overall_negative_prompt, text_encoder=text_encoder)
169
-
170
  input_latents_list, latents_bg = latents.get_input_latents_list(
171
- model_dict, bg_seed=bg_seed, fg_seed_start=fg_seed_start,
172
- so_boxes=so_boxes, fg_blending_ratio=fg_blending_ratio, height=height, width=width, verbose=False
173
  )
174
- latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
175
- so_prompt_phrase_word_box_list, input_latents_list,
176
- gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
177
- sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size,
178
- fast_after_steps=frozen_steps if use_fast_schedule else None, fast_rate=2
179
  )
180
 
181
- composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
182
- model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
183
- overall_batch_size, height, width, latents_bg=latents_bg,
184
- align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
185
- horizontal_shift_only=horizontal_shift_only, use_fast_schedule=use_fast_schedule, fast_after_steps=frozen_steps
186
  )
187
-
188
  overall_bboxes_flattened, overall_phrases_flattened = [], []
189
  for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
190
  for overall_bbox in overall_bboxes_item:
191
  overall_bboxes_flattened.append(overall_bbox)
192
  overall_phrases_flattened.append(overall_phrase)
193
 
194
  # Generate with composed latents
195
 
196
  # Foreground should be frozen
197
  frozen_mask = foreground_indices != 0
198
-
199
- regen_latents, images = pipelines.generate_gligen(
200
- model_dict, composed_latents, overall_input_embeddings, num_inference_steps,
201
- overall_bboxes_flattened, overall_phrases_flattened, guidance_scale=guidance_scale,
202
- gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
203
- frozen_steps=frozen_steps, frozen_mask=frozen_mask, scheduler_key=scheduler_key
 
204
  )
205
 
206
- print(f"Generation with spatial guidance from input latents and first {frozen_steps} steps frozen (directly from the composed latents input)")
 
 
207
  print("Generation from composed latents (with semantic guidance)")
208
 
209
- # display(Image.fromarray(images[0]), "img", run_ind)
210
-
211
- gc.collect()
212
- torch.cuda.empty_cache()
213
-
214
- return images[0], so_img_list
215
 
1
  import torch
 
2
  import models
3
  import utils
4
  from models import pipelines, sam
5
+ from utils import parse, guidance, attn, latents, vis
6
+ from shared import (
7
+ model_dict,
8
+ sam_model_dict,
9
+ DEFAULT_SO_NEGATIVE_PROMPT,
10
+ DEFAULT_OVERALL_NEGATIVE_PROMPT,
11
+ )
12
 
13
  verbose = False
 
 
14
 
15
+ vae, tokenizer, text_encoder, unet, dtype = (
16
+ model_dict.vae,
17
+ model_dict.tokenizer,
18
+ model_dict.text_encoder,
19
+ model_dict.unet,
20
+ model_dict.dtype,
21
+ )
22
 
23
  model_dict.update(sam_model_dict)
24
 
 
26
  # Hyperparams
27
  height = 512 # default height of Stable Diffusion
28
  width = 512 # default width of Stable Diffusion
29
+ H, W = height // 8, width // 8 # size of the latent
30
  guidance_scale = 7.5 # Scale for classifier-free guidance
31
 
32
  # batch size that is not 1 is not supported
33
  overall_batch_size = 1
34
 
35
+ # semantic guidance kwargs (single object)
36
+ guidance_attn_keys = pipelines.DEFAULT_GUIDANCE_ATTN_KEYS
37
+
38
  # discourage masks with confidence below
39
  discourage_mask_below_confidence = 0.85
40
 
41
  # discourage masks with iou (with coarse binarized attention mask) below
42
  discourage_mask_below_coarse_iou = 0.25
43
 
44
+ # This controls the foreground variations
45
+ fg_blending_ratio = 0.1
46
+
47
  run_ind = None
48
 
49
 
50
+ def generate_single_object_with_box(
51
+ prompt,
52
+ box,
53
+ phrase,
54
+ word,
55
+ input_latents,
56
+ input_embeddings,
57
+ semantic_guidance_kwargs,
58
+ obj_attn_key,
59
+ saved_cross_attn_keys,
60
+ sam_refine_kwargs,
61
+ num_inference_steps,
62
+ gligen_scheduled_sampling_beta=0.3,
63
+ verbose=False,
64
+ visualize=False,
65
+ **kwargs,
66
+ ):
67
+ bboxes, phrases, words = [box], [phrase], [word]
68
+
69
+ if verbose:
70
+ print(f"Getting token map (prompt: {prompt})")
71
+
72
+ object_positions, word_token_indices = guidance.get_phrase_indices(
73
+ tokenizer=tokenizer,
74
+ prompt=prompt,
75
+ phrases=phrases,
76
+ words=words,
77
+ return_word_token_indices=True,
78
+ # Since the prompt for single object is from background prompt + object name, we will not have the case of not found
79
+ add_suffix_if_not_found=False,
80
+ verbose=verbose,
81
+ )
82
+ # phrases only has one item, so we select the first item in word_token_indices
83
+ word_token_index = word_token_indices[0]
84
+
85
+ if verbose:
86
+ print("word_token_index:", word_token_index)
87
+
88
+ # `offload_guidance_cross_attn_to_cpu` will greatly slow down generation
89
+ (
90
+ latents,
91
+ single_object_images,
92
+ saved_attns,
93
+ single_object_pil_images_box_ann,
94
+ latents_all,
95
+ ) = pipelines.generate_gligen(
96
+ model_dict,
97
+ input_latents,
98
+ input_embeddings,
99
+ num_inference_steps,
100
+ bboxes,
101
+ phrases,
102
+ gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
103
+ guidance_scale=guidance_scale,
104
+ return_saved_cross_attn=True,
105
+ semantic_guidance=True,
106
+ semantic_guidance_bboxes=bboxes,
107
+ semantic_guidance_object_positions=object_positions,
108
+ semantic_guidance_kwargs=semantic_guidance_kwargs,
109
+ saved_cross_attn_keys=[obj_attn_key, *saved_cross_attn_keys],
110
+ return_cond_ca_only=True,
111
+ return_token_ca_only=word_token_index,
112
+ offload_cross_attn_to_cpu=False,
113
+ return_box_vis=True,
114
+ save_all_latents=True,
115
+ dynamic_num_inference_steps=True,
116
+ **kwargs,
117
+ )
118
+ # `saved_cross_attn_keys` kwargs may have duplicates
119
+
120
+ utils.free_memory()
121
+
122
+ single_object_pil_image_box_ann = single_object_pil_images_box_ann[0]
123
+
124
+ if visualize:
125
+ print("Single object image")
126
+ vis.display(single_object_pil_image_box_ann)
127
+
128
+ mask_selected, conf_score_selected = sam.sam_refine_box(
129
+ sam_input_image=single_object_images[0],
130
+ box=box,
131
+ model_dict=model_dict,
132
+ verbose=verbose,
133
+ **sam_refine_kwargs,
134
+ )
135
+
136
  mask_selected_tensor = torch.tensor(mask_selected)
137
+
138
+ if verbose:
139
+ vis.visualize(mask_selected, "Mask (selected) after resize")
140
+ # This is only for visualizations
141
+ masked_latents = latents_all * mask_selected_tensor[None, None, None, ...]
142
+ vis.visualize_masked_latents(
143
+ latents_all, masked_latents, timestep_T=False, timestep_0=True
144
+ )
145
+
146
+ return (
147
+ latents_all,
148
+ mask_selected_tensor,
149
+ saved_attns,
150
+ single_object_pil_image_box_ann,
151
+ )
152
+
153
+
154
+ def get_masked_latents_all_list(
155
+ so_prompt_phrase_word_box_list,
156
+ input_latents_list,
157
+ so_input_embeddings,
158
+ verbose=False,
159
+ **kwargs,
160
+ ):
161
+ latents_all_list, mask_tensor_list, saved_attns_list, so_img_list = [], [], [], []
162
+
163
  if not so_prompt_phrase_word_box_list:
164
+ return latents_all_list, mask_tensor_list, saved_attns_list
 
 
165
 
166
+ so_uncond_embeddings, so_cond_embeddings = so_input_embeddings
167
+
168
+ for idx, ((prompt, phrase, word, box), input_latents) in enumerate(
169
+ zip(so_prompt_phrase_word_box_list, input_latents_list)
170
+ ):
171
+ so_current_cond_embeddings = so_cond_embeddings[idx : idx + 1]
172
+ so_current_text_embeddings = torch.cat(
173
+ [so_uncond_embeddings, so_current_cond_embeddings], dim=0
174
+ )
175
+ so_current_input_embeddings = (
176
+ so_current_text_embeddings,
177
+ so_uncond_embeddings,
178
+ so_current_cond_embeddings,
179
+ )
180
+
181
+ latents_all, mask_tensor, saved_attns, so_img = generate_single_object_with_box(
182
+ prompt,
183
+ box,
184
+ phrase,
185
+ word,
186
+ input_latents,
187
+ input_embeddings=so_current_input_embeddings,
188
+ verbose=verbose,
189
+ **kwargs,
190
+ )
191
+ latents_all_list.append(latents_all)
192
+ mask_tensor_list.append(mask_tensor)
193
+ saved_attns_list.append(saved_attns)
194
+ so_img_list.append(so_img)
195
 
196
+ return latents_all_list, mask_tensor_list, saved_attns_list, so_img_list
197
 
198
 
199
  # Note: need to keep the supervision, especially the box coordinates, corresponding to each other in single object and overall.
200
 
201
+
202
  def run(
203
+ spec,
204
+ bg_seed=1,
205
+ overall_prompt_override="",
206
+ fg_seed_start=20,
207
+ frozen_step_ratio=0.4,
208
+ num_inference_steps=20,
209
+ loss_scale=20,
210
+ loss_threshold=5.0,
211
+ max_iter=[2] * 5 + [1] * 10,
212
+ max_index_step=15,
213
+ overall_loss_scale=20,
214
+ overall_loss_threshold=5.0,
215
+ overall_max_iter=[4] * 5 + [3] * 5 + [2] * 5 + [2] * 5 + [1] * 10,
216
+ overall_max_index_step=30,
217
+ so_gligen_scheduled_sampling_beta=0.4,
218
+ overall_gligen_scheduled_sampling_beta=0.4,
219
+ ref_ca_loss_weight=0.5,
220
+ so_center_box=False,
221
+ fg_blending_ratio=0.1,
222
+ scheduler_key="dpm_scheduler",
223
+ so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT,
224
+ overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT,
225
+ so_horizontal_center_only=True,
226
+ align_with_overall_bboxes=False,
227
+ horizontal_shift_only=True,
228
+ use_fast_schedule=True,
229
+ # Transfer the cross-attention maps from single-object generation (via ref_ca_saved_attns)
230
+ # and use them as reference to guide the cross-attention in the overall generation
231
+ use_ref_ca=True,
232
+ use_autocast=False,
233
  ):
234
+ """
235
  so_center_box: using centered box in single object generation
236
  so_horizontal_center_only: move to the center horizontally only
237
+
238
  align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
239
  horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
240
  """
241
+
242
+ frozen_step_ratio = min(max(frozen_step_ratio, 0.0), 1.0)
 
 
243
  frozen_steps = int(num_inference_steps * frozen_step_ratio)
244
 
245
+ print(
246
+ "generation:",
247
+ spec,
248
+ bg_seed,
249
+ fg_seed_start,
250
+ frozen_step_ratio,
251
+ so_gligen_scheduled_sampling_beta,
252
+ overall_gligen_scheduled_sampling_beta,
253
+ overall_max_index_step,
254
+ )
255
+
256
+ (
257
+ so_prompt_phrase_word_box_list,
258
+ overall_prompt,
259
+ overall_phrases_words_bboxes,
260
+ ) = parse.convert_spec(spec, height, width, verbose=verbose)
261
 
262
  if overall_prompt_override and overall_prompt_override.strip():
263
  overall_prompt = overall_prompt_override.strip()
264
 
265
+ overall_phrases, overall_words, overall_bboxes = (
266
+ [item[0] for item in overall_phrases_words_bboxes],
267
+ [item[1] for item in overall_phrases_words_bboxes],
268
+ [item[2] for item in overall_phrases_words_bboxes],
269
+ )
270
 
271
  # The so box is centered but the overall boxes are not (since we need to place to the right place).
272
  if so_center_box:
273
+ so_prompt_phrase_word_box_list = [
274
+ (
275
+ prompt,
276
+ phrase,
277
+ word,
278
+ utils.get_centered_box(
279
+ bbox, horizontal_center_only=so_horizontal_center_only
280
+ ),
281
+ )
282
+ for prompt, phrase, word, bbox in so_prompt_phrase_word_box_list
283
+ ]
284
  if verbose:
285
+ print(
286
+ f"centered so_prompt_phrase_word_box_list: {so_prompt_phrase_word_box_list}"
287
+ )
288
  so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
289
 
290
+ so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT
291
+ overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT
292
+ if "extra_neg_prompt" in spec and spec["extra_neg_prompt"]:
293
+ so_negative_prompt = spec["extra_neg_prompt"] + ", " + so_negative_prompt
294
+ overall_negative_prompt = (
295
+ spec["extra_neg_prompt"] + ", " + overall_negative_prompt
296
+ )
297
+
298
+ semantic_guidance_kwargs = dict(
299
+ loss_scale=loss_scale,
300
+ loss_threshold=loss_threshold,
301
+ max_iter=max_iter,
302
+ max_index_step=max_index_step,
303
+ use_ratio_based_loss=False,
304
+ guidance_attn_keys=guidance_attn_keys,
305
+ verbose=True,
306
+ )
307
+
308
  sam_refine_kwargs = dict(
309
+ discourage_mask_below_confidence=discourage_mask_below_confidence,
310
+ discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
311
+ height=height,
312
+ width=width,
313
+ H=H,
314
+ W=W,
315
  )
316
+
317
+ if verbose:
318
+ vis.visualize_bboxes(
319
+ bboxes=[item[-1] for item in so_prompt_phrase_word_box_list], H=H, W=W
320
+ )
321
+
322
  # Note that so and overall use different negative prompts
323
 
324
  with torch.autocast("cuda", enabled=use_autocast):
325
  so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
326
  if so_prompts:
327
+ so_input_embeddings = models.encode_prompts(
328
+ prompts=so_prompts,
329
+ tokenizer=tokenizer,
330
+ text_encoder=text_encoder,
331
+ negative_prompt=so_negative_prompt,
332
+ one_uncond_input_only=True,
333
+ )
334
  else:
335
  so_input_embeddings = []
336
 
 
 
337
  input_latents_list, latents_bg = latents.get_input_latents_list(
338
+ model_dict,
339
+ bg_seed=bg_seed,
340
+ fg_seed_start=fg_seed_start,
341
+ so_boxes=so_boxes,
342
+ fg_blending_ratio=fg_blending_ratio,
343
+ height=height,
344
+ width=width,
345
+ verbose=False,
346
+ )
347
+
348
+ if use_fast_schedule:
349
+ fast_after_steps = max(frozen_steps, overall_max_index_step) if use_ref_ca else frozen_steps
350
+ else:
351
+ fast_after_steps = None
352
+
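A hedged sketch of what such a fast schedule could look like (assumed behavior; the repo's `schedule.get_fast_schedule` may differ): keep every timestep up to `fast_after_steps`, then only every `fast_rate`-th timestep afterwards.

```python
import torch

def fast_schedule(timesteps, fast_after_steps, fast_rate=2):
    # Dense steps first, then a strided (faster) tail.
    return torch.cat([timesteps[:fast_after_steps], timesteps[fast_after_steps::fast_rate]])

timesteps = torch.arange(951, 0, -50)  # 20 hypothetical timesteps: 951, 901, ..., 1
print(fast_schedule(timesteps, fast_after_steps=8, fast_rate=2))
```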
353
+ if use_ref_ca or frozen_steps > 0:
354
+ (
355
+ latents_all_list,
356
+ mask_tensor_list,
357
+ saved_attns_list,
358
+ so_img_list,
359
+ ) = get_masked_latents_all_list(
360
+ so_prompt_phrase_word_box_list,
361
+ input_latents_list,
362
+ gligen_scheduled_sampling_beta=so_gligen_scheduled_sampling_beta,
363
+ semantic_guidance_kwargs=semantic_guidance_kwargs,
364
+ obj_attn_key=("down", 2, 1, 0),
365
+ saved_cross_attn_keys=guidance_attn_keys if use_ref_ca else [],
366
+ sam_refine_kwargs=sam_refine_kwargs,
367
+ so_input_embeddings=so_input_embeddings,
368
+ num_inference_steps=num_inference_steps,
369
+ scheduler_key=scheduler_key,
370
+ verbose=verbose,
371
+ fast_after_steps=fast_after_steps,
372
+ fast_rate=2,
373
+ )
374
+ else:
375
+ # No per-box guidance
376
+ (latents_all_list, mask_tensor_list, saved_attns_list, so_img_list) = [], [], [], []
377
+
378
+ (
379
+ composed_latents,
380
+ foreground_indices,
381
+ offset_list,
382
+ ) = latents.compose_latents_with_alignment(
383
+ model_dict,
384
+ latents_all_list,
385
+ mask_tensor_list,
386
+ num_inference_steps,
387
+ overall_batch_size,
388
+ height,
389
+ width,
390
+ latents_bg=latents_bg,
391
+ align_with_overall_bboxes=align_with_overall_bboxes,
392
+ overall_bboxes=overall_bboxes,
393
+ horizontal_shift_only=horizontal_shift_only,
394
+ use_fast_schedule=use_fast_schedule,
395
+ fast_after_steps=fast_after_steps,
396
  )
397
+
398
+ # NOTE: the overall embeddings must be generated after the overall prompt has been updated
399
+ (
400
+ overall_object_positions,
401
+ overall_word_token_indices,
402
+ overall_prompt
403
+ ) = guidance.get_phrase_indices(
404
+ tokenizer=tokenizer,
405
+ prompt=overall_prompt,
406
+ phrases=overall_phrases,
407
+ words=overall_words,
408
+ verbose=verbose,
409
+ return_word_token_indices=True,
410
+ add_suffix_if_not_found=True
411
  )
412
 
413
+ overall_input_embeddings = models.encode_prompts(
414
+ prompts=[overall_prompt],
415
+ tokenizer=tokenizer,
416
+ negative_prompt=overall_negative_prompt,
417
+ text_encoder=text_encoder,
418
  )
419
+
420
+ if use_ref_ca:
421
+ # ref_ca_saved_attns mirrors the nesting of overall_bboxes (one inner list per phrase)
422
+ ref_ca_saved_attns = []
423
+
424
+ flattened_box_idx = 0
425
+ for bboxes in overall_bboxes:
426
+ # bboxes: the boxes corresponding to a single phrase
427
+ ref_ca_current_phrase_saved_attns = []
428
+ for bbox in bboxes:
429
+ # each individual bbox
430
+ saved_attns = saved_attns_list[flattened_box_idx]
431
+ if align_with_overall_bboxes:
432
+ offset = offset_list[flattened_box_idx]
433
+ saved_attns = attn.shift_saved_attns(
434
+ saved_attns,
435
+ offset,
436
+ guidance_attn_keys=guidance_attn_keys,
437
+ horizontal_shift_only=horizontal_shift_only,
438
+ )
439
+ ref_ca_current_phrase_saved_attns.append(saved_attns)
440
+ flattened_box_idx += 1
441
+ ref_ca_saved_attns.append(ref_ca_current_phrase_saved_attns)
442
+
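A minimal sketch of the nesting built above, with hypothetical data and the box-alignment shift omitted: `ref_ca_saved_attns` mirrors `overall_bboxes`, one inner list of per-box saved-attention dicts per phrase.

```python
overall_bboxes = [
    [(0.1, 0.2, 0.3, 0.4)],                        # phrase 0: one box
    [(0.5, 0.2, 0.3, 0.4), (0.1, 0.6, 0.3, 0.3)],  # phrase 1: two boxes
]
saved_attns_list = [{"attn": 0}, {"attn": 1}, {"attn": 2}]  # flattened, one entry per box

ref_ca_saved_attns, flat_idx = [], 0
for bboxes in overall_bboxes:
    ref_ca_saved_attns.append([saved_attns_list[flat_idx + i] for i in range(len(bboxes))])
    flat_idx += len(bboxes)

assert [len(x) for x in ref_ca_saved_attns] == [1, 2]
```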
443
  overall_bboxes_flattened, overall_phrases_flattened = [], []
444
  for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
445
  for overall_bbox in overall_bboxes_item:
446
  overall_bboxes_flattened.append(overall_bbox)
447
  overall_phrases_flattened.append(overall_phrase)
448
 
449
+ # These guidance kwargs are currently not shared with the single-object ones.
450
+ overall_semantic_guidance_kwargs = dict(
451
+ loss_scale=overall_loss_scale,
452
+ loss_threshold=overall_loss_threshold,
453
+ max_iter=overall_max_iter,
454
+ max_index_step=overall_max_index_step,
455
+ # ref_ca comes from the attention map of the word token of the phrase in single object generation, so we apply it only to the word token of the phrase in overall generation.
456
+ ref_ca_word_token_only=True,
457
+ # If a word is not provided, we use the last token.
458
+ ref_ca_last_token_only=True,
459
+ ref_ca_saved_attns=ref_ca_saved_attns if use_ref_ca else None,
460
+ word_token_indices=overall_word_token_indices,
461
+ guidance_attn_keys=guidance_attn_keys,
462
+ ref_ca_loss_weight=ref_ca_loss_weight,
463
+ use_ratio_based_loss=False,
464
+ verbose=True,
465
+ )
466
+
467
  # Generate with composed latents
468
 
469
  # Foreground should be frozen
470
  frozen_mask = foreground_indices != 0
471
+
472
+ _, images = pipelines.generate_gligen(
473
+ model_dict,
474
+ composed_latents,
475
+ overall_input_embeddings,
476
+ num_inference_steps,
477
+ overall_bboxes_flattened,
478
+ overall_phrases_flattened,
479
+ guidance_scale=guidance_scale,
480
+ gligen_scheduled_sampling_beta=overall_gligen_scheduled_sampling_beta,
481
+ semantic_guidance=True,
482
+ semantic_guidance_bboxes=overall_bboxes,
483
+ semantic_guidance_object_positions=overall_object_positions,
484
+ semantic_guidance_kwargs=overall_semantic_guidance_kwargs,
485
+ frozen_steps=frozen_steps,
486
+ frozen_mask=frozen_mask,
487
+ scheduler_key=scheduler_key,
488
  )
489
 
490
+ print(
491
+ f"Generation with spatial guidance from input latents and first {frozen_steps} steps frozen (directly from the composed latents input)"
492
+ )
493
  print("Generation from composed latents (with semantic guidance)")
494
 
495
+ utils.free_memory()
 
 
 
 
 
496
 
497
+ return images[0], so_img_list
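For intuition, the `frozen_mask = foreground_indices != 0` computed above is used inside `generate_gligen` to keep the composed foreground latents fixed for the first `frozen_steps` denoising steps. A minimal torch-only sketch of that per-step blend with assumed SD latent shapes and dummy values:

```python
import torch

latents = torch.randn(1, 4, 64, 64)                                 # current overall latents at this step
latents_all_input = [torch.randn(1, 4, 64, 64) for _ in range(21)]  # composed per-step latents
frozen_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()              # 1 where the foreground is frozen

index, frozen_steps = 3, 8
if index < frozen_steps:
    # Keep the composed foreground; let the background continue to evolve.
    latents = latents_all_input[index + 1] * frozen_mask + latents * (1.0 - frozen_mask)
```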
models/modeling_utils.py DELETED
@@ -1,874 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import inspect
18
- import itertools
19
- import os
20
- from functools import partial
21
- from typing import Any, Callable, List, Optional, Tuple, Union
22
-
23
- import torch
24
- from torch import Tensor, device
25
-
26
- from diffusers import __version__
27
- from diffusers.utils import (
28
- CONFIG_NAME,
29
- DIFFUSERS_CACHE,
30
- FLAX_WEIGHTS_NAME,
31
- HF_HUB_OFFLINE,
32
- SAFETENSORS_WEIGHTS_NAME,
33
- WEIGHTS_NAME,
34
- _add_variant,
35
- _get_model_file,
36
- deprecate,
37
- is_accelerate_available,
38
- is_safetensors_available,
39
- is_torch_version,
40
- logging,
41
- )
42
-
43
-
44
- logger = logging.get_logger(__name__)
45
-
46
-
47
- if is_torch_version(">=", "1.9.0"):
48
- _LOW_CPU_MEM_USAGE_DEFAULT = True
49
- else:
50
- _LOW_CPU_MEM_USAGE_DEFAULT = False
51
-
52
-
53
- if is_accelerate_available():
54
- import accelerate
55
- from accelerate.utils import set_module_tensor_to_device
56
- from accelerate.utils.versions import is_torch_version
57
-
58
- if is_safetensors_available():
59
- import safetensors
60
-
61
-
62
- def get_parameter_device(parameter: torch.nn.Module):
63
- try:
64
- parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
65
- return next(parameters_and_buffers).device
66
- except StopIteration:
67
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
68
-
69
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
70
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
71
- return tuples
72
-
73
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
74
- first_tuple = next(gen)
75
- return first_tuple[1].device
76
-
77
-
78
- def get_parameter_dtype(parameter: torch.nn.Module):
79
- try:
80
- params = tuple(parameter.parameters())
81
- if len(params) > 0:
82
- return params[0].dtype
83
-
84
- buffers = tuple(parameter.buffers())
85
- if len(buffers) > 0:
86
- return buffers[0].dtype
87
-
88
- except StopIteration:
89
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
90
-
91
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
92
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
93
- return tuples
94
-
95
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
96
- first_tuple = next(gen)
97
- return first_tuple[1].dtype
98
-
99
-
100
- def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
101
- """
102
- Reads a checkpoint file, returning properly formatted errors if they arise.
103
- """
104
- try:
105
- if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
106
- return torch.load(checkpoint_file, map_location="cpu")
107
- else:
108
- return safetensors.torch.load_file(checkpoint_file, device="cpu")
109
- except Exception as e:
110
- try:
111
- with open(checkpoint_file) as f:
112
- if f.read().startswith("version"):
113
- raise OSError(
114
- "You seem to have cloned a repository without having git-lfs installed. Please install "
115
- "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
116
- "you cloned."
117
- )
118
- else:
119
- raise ValueError(
120
- f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
121
- "model. Make sure you have saved the model properly."
122
- ) from e
123
- except (UnicodeDecodeError, ValueError):
124
- raise OSError(
125
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
126
- f"at '{checkpoint_file}'. "
127
- "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
128
- )
129
-
130
-
131
- def _load_state_dict_into_model(model_to_load, state_dict):
132
- # Convert old format to new format if needed from a PyTorch state_dict
133
- # copy state_dict so _load_from_state_dict can modify it
134
- state_dict = state_dict.copy()
135
- error_msgs = []
136
-
137
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
138
- # so we need to apply the function recursively.
139
- def load(module: torch.nn.Module, prefix=""):
140
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
141
- module._load_from_state_dict(*args)
142
-
143
- for name, child in module._modules.items():
144
- if child is not None:
145
- load(child, prefix + name + ".")
146
-
147
- load(model_to_load)
148
-
149
- return error_msgs
150
-
151
-
152
- class ModelMixin(torch.nn.Module):
153
- r"""
154
- Base class for all models.
155
-
156
- [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
157
- and saving models.
158
-
159
- - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
160
- [`~models.ModelMixin.save_pretrained`].
161
- """
162
- config_name = CONFIG_NAME
163
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164
- _supports_gradient_checkpointing = False
165
-
166
- def __init__(self):
167
- super().__init__()
168
-
169
- def __getattr__(self, name: str) -> Any:
170
- """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
171
- config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
172
- __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
173
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
174
- """
175
-
176
- is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
177
- is_attribute = name in self.__dict__
178
-
179
- if is_in_config and not is_attribute:
180
- deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
181
- deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
182
- return self._internal_dict[name]
183
-
184
- # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
185
- return super().__getattr__(name)
186
-
187
- @property
188
- def is_gradient_checkpointing(self) -> bool:
189
- """
190
- Whether gradient checkpointing is activated for this model or not.
191
-
192
- Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
193
- activations".
194
- """
195
- return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
196
-
197
- def enable_gradient_checkpointing(self):
198
- """
199
- Activates gradient checkpointing for the current model.
200
-
201
- Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
202
- activations".
203
- """
204
- if not self._supports_gradient_checkpointing:
205
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
206
- self.apply(partial(self._set_gradient_checkpointing, value=True))
207
-
208
- def disable_gradient_checkpointing(self):
209
- """
210
- Deactivates gradient checkpointing for the current model.
211
-
212
- Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
213
- activations".
214
- """
215
- if self._supports_gradient_checkpointing:
216
- self.apply(partial(self._set_gradient_checkpointing, value=False))
217
-
218
- def set_use_memory_efficient_attention_xformers(
219
- self, valid: bool, attention_op: Optional[Callable] = None
220
- ) -> None:
221
- # Recursively walk through all the children.
222
- # Any children which exposes the set_use_memory_efficient_attention_xformers method
223
- # gets the message
224
- def fn_recursive_set_mem_eff(module: torch.nn.Module):
225
- if hasattr(module, "set_use_memory_efficient_attention_xformers"):
226
- module.set_use_memory_efficient_attention_xformers(valid, attention_op)
227
-
228
- for child in module.children():
229
- fn_recursive_set_mem_eff(child)
230
-
231
- for module in self.children():
232
- if isinstance(module, torch.nn.Module):
233
- fn_recursive_set_mem_eff(module)
234
-
235
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
236
- r"""
237
- Enable memory efficient attention as implemented in xformers.
238
-
239
- When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
240
- time. Speed up at training time is not guaranteed.
241
-
242
- Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
243
- is used.
244
-
245
- Parameters:
246
- attention_op (`Callable`, *optional*):
247
- Override the default `None` operator for use as `op` argument to the
248
- [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
249
- function of xFormers.
250
-
251
- Examples:
252
-
253
- ```py
254
- >>> import torch
255
- >>> from diffusers import UNet2DConditionModel
256
- >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
257
-
258
- >>> model = UNet2DConditionModel.from_pretrained(
259
- ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
260
- ... )
261
- >>> model = model.to("cuda")
262
- >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
263
- ```
264
- """
265
- self.set_use_memory_efficient_attention_xformers(True, attention_op)
266
-
267
- def disable_xformers_memory_efficient_attention(self):
268
- r"""
269
- Disable memory efficient attention as implemented in xformers.
270
- """
271
- self.set_use_memory_efficient_attention_xformers(False)
272
-
273
- def save_pretrained(
274
- self,
275
- save_directory: Union[str, os.PathLike],
276
- is_main_process: bool = True,
277
- save_function: Callable = None,
278
- safe_serialization: bool = False,
279
- variant: Optional[str] = None,
280
- ):
281
- """
282
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
283
- `[`~models.ModelMixin.from_pretrained`]` class method.
284
-
285
- Arguments:
286
- save_directory (`str` or `os.PathLike`):
287
- Directory to which to save. Will be created if it doesn't exist.
288
- is_main_process (`bool`, *optional*, defaults to `True`):
289
- Whether the process calling this is the main process or not. Useful when in distributed training like
290
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
291
- the main process to avoid race conditions.
292
- save_function (`Callable`):
293
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
294
- need to replace `torch.save` by another method. Can be configured with the environment variable
295
- `DIFFUSERS_SAVE_MODE`.
296
- safe_serialization (`bool`, *optional*, defaults to `False`):
297
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
298
- variant (`str`, *optional*):
299
- If specified, weights are saved in the format pytorch_model.<variant>.bin.
300
- """
301
- if safe_serialization and not is_safetensors_available():
302
- raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
303
-
304
- if os.path.isfile(save_directory):
305
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
306
- return
307
-
308
- os.makedirs(save_directory, exist_ok=True)
309
-
310
- model_to_save = self
311
-
312
- # Attach architecture to the config
313
- # Save the config
314
- if is_main_process:
315
- model_to_save.save_config(save_directory)
316
-
317
- # Save the model
318
- state_dict = model_to_save.state_dict()
319
-
320
- weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
321
- weights_name = _add_variant(weights_name, variant)
322
-
323
- # Save the model
324
- if safe_serialization:
325
- safetensors.torch.save_file(
326
- state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
327
- )
328
- else:
329
- torch.save(state_dict, os.path.join(save_directory, weights_name))
330
-
331
- logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
332
-
333
- @classmethod
334
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
335
- r"""
336
- Instantiate a pretrained pytorch model from a pre-trained model configuration.
337
-
338
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
339
- the model, you should first set it back in training mode with `model.train()`.
340
-
341
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
342
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
343
- task.
344
-
345
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
346
- weights are discarded.
347
-
348
- Parameters:
349
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
350
- Can be either:
351
-
352
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
353
- Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
354
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
355
- `./my_model_directory/`.
356
-
357
- cache_dir (`Union[str, os.PathLike]`, *optional*):
358
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
359
- standard cache should not be used.
360
- torch_dtype (`str` or `torch.dtype`, *optional*):
361
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
362
- will be automatically derived from the model's weights.
363
- force_download (`bool`, *optional*, defaults to `False`):
364
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
365
- cached versions if they exist.
366
- resume_download (`bool`, *optional*, defaults to `False`):
367
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
368
- file exists.
369
- proxies (`Dict[str, str]`, *optional*):
370
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
371
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
372
- output_loading_info(`bool`, *optional*, defaults to `False`):
373
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
374
- local_files_only(`bool`, *optional*, defaults to `False`):
375
- Whether or not to only look at local files (i.e., do not try to download the model).
376
- use_auth_token (`str` or *bool*, *optional*):
377
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
378
- when running `diffusers-cli login` (stored in `~/.huggingface`).
379
- revision (`str`, *optional*, defaults to `"main"`):
380
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
381
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
382
- identifier allowed by git.
383
- from_flax (`bool`, *optional*, defaults to `False`):
384
- Load the model weights from a Flax checkpoint save file.
385
- subfolder (`str`, *optional*, defaults to `""`):
386
- In case the relevant files are located inside a subfolder of the model repo (either remote in
387
- huggingface.co or downloaded locally), you can specify the folder name here.
388
-
389
- mirror (`str`, *optional*):
390
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
391
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
392
- Please refer to the mirror site for more information.
393
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
394
- A map that specifies where each submodule should go. It doesn't need to be refined to each
395
- parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
396
- same device.
397
-
398
- To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
399
- more information about each option see [designing a device
400
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
401
- max_memory (`Dict`, *optional*):
402
- A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
403
- GPU and the available CPU RAM if unset.
404
- offload_folder (`str` or `os.PathLike`, *optional*):
405
- If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
406
- offload_state_dict (`bool`, *optional*):
407
- If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
408
- RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
409
- `True` when there is some disk offload.
410
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
411
- Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
412
- also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
413
- model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
414
- setting this argument to `True` will raise an error.
415
- variant (`str`, *optional*):
416
- If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
417
- ignored when using `from_flax`.
418
- use_safetensors (`bool`, *optional*, defaults to `None`):
419
- If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
420
- `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
421
- `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
422
-
423
- <Tip>
424
-
425
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
426
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
427
-
428
- </Tip>
429
-
430
- <Tip>
431
-
432
- Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
433
- this method in a firewalled environment.
434
-
435
- </Tip>
436
-
437
- """
438
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
439
- ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
440
- force_download = kwargs.pop("force_download", False)
441
- from_flax = kwargs.pop("from_flax", False)
442
- resume_download = kwargs.pop("resume_download", False)
443
- proxies = kwargs.pop("proxies", None)
444
- output_loading_info = kwargs.pop("output_loading_info", False)
445
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
446
- use_auth_token = kwargs.pop("use_auth_token", None)
447
- revision = kwargs.pop("revision", None)
448
- torch_dtype = kwargs.pop("torch_dtype", None)
449
- subfolder = kwargs.pop("subfolder", None)
450
- device_map = kwargs.pop("device_map", None)
451
- max_memory = kwargs.pop("max_memory", None)
452
- offload_folder = kwargs.pop("offload_folder", None)
453
- offload_state_dict = kwargs.pop("offload_state_dict", False)
454
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
455
- variant = kwargs.pop("variant", None)
456
- use_safetensors = kwargs.pop("use_safetensors", None)
457
-
458
- if use_safetensors and not is_safetensors_available():
459
- raise ValueError(
460
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
461
- )
462
-
463
- allow_pickle = False
464
- if use_safetensors is None:
465
- use_safetensors = is_safetensors_available()
466
- allow_pickle = True
467
-
468
- if low_cpu_mem_usage and not is_accelerate_available():
469
- low_cpu_mem_usage = False
470
- logger.warning(
471
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
472
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
473
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
474
- " install accelerate\n```\n."
475
- )
476
-
477
- if device_map is not None and not is_accelerate_available():
478
- raise NotImplementedError(
479
- "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
480
- " `device_map=None`. You can install accelerate with `pip install accelerate`."
481
- )
482
-
483
- # Check if we can handle device_map and dispatching the weights
484
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
485
- raise NotImplementedError(
486
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
487
- " `device_map=None`."
488
- )
489
-
490
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
491
- raise NotImplementedError(
492
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
493
- " `low_cpu_mem_usage=False`."
494
- )
495
-
496
- if low_cpu_mem_usage is False and device_map is not None:
497
- raise ValueError(
498
- f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
499
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
500
- )
501
-
502
- # Load config if we don't provide a configuration
503
- config_path = pretrained_model_name_or_path
504
-
505
- user_agent = {
506
- "diffusers": __version__,
507
- "file_type": "model",
508
- "framework": "pytorch",
509
- }
510
-
511
- # load config
512
- config, unused_kwargs, commit_hash = cls.load_config(
513
- config_path,
514
- cache_dir=cache_dir,
515
- return_unused_kwargs=True,
516
- return_commit_hash=True,
517
- force_download=force_download,
518
- resume_download=resume_download,
519
- proxies=proxies,
520
- local_files_only=local_files_only,
521
- use_auth_token=use_auth_token,
522
- revision=revision,
523
- subfolder=subfolder,
524
- device_map=device_map,
525
- max_memory=max_memory,
526
- offload_folder=offload_folder,
527
- offload_state_dict=offload_state_dict,
528
- user_agent=user_agent,
529
- **kwargs,
530
- )
531
-
532
- # load model
533
- model_file = None
534
- if from_flax:
535
- model_file = _get_model_file(
536
- pretrained_model_name_or_path,
537
- weights_name=FLAX_WEIGHTS_NAME,
538
- cache_dir=cache_dir,
539
- force_download=force_download,
540
- resume_download=resume_download,
541
- proxies=proxies,
542
- local_files_only=local_files_only,
543
- use_auth_token=use_auth_token,
544
- revision=revision,
545
- subfolder=subfolder,
546
- user_agent=user_agent,
547
- commit_hash=commit_hash,
548
- )
549
- model = cls.from_config(config, **unused_kwargs)
550
-
551
- # Convert the weights
552
- from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
553
-
554
- model = load_flax_checkpoint_in_pytorch_model(model, model_file)
555
- else:
556
- if use_safetensors:
557
- try:
558
- model_file = _get_model_file(
559
- pretrained_model_name_or_path,
560
- weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
561
- cache_dir=cache_dir,
562
- force_download=force_download,
563
- resume_download=resume_download,
564
- proxies=proxies,
565
- local_files_only=local_files_only,
566
- use_auth_token=use_auth_token,
567
- revision=revision,
568
- subfolder=subfolder,
569
- user_agent=user_agent,
570
- commit_hash=commit_hash,
571
- )
572
- except IOError as e:
573
- if not allow_pickle:
574
- raise e
575
- pass
576
- if model_file is None:
577
- model_file = _get_model_file(
578
- pretrained_model_name_or_path,
579
- weights_name=_add_variant(WEIGHTS_NAME, variant),
580
- cache_dir=cache_dir,
581
- force_download=force_download,
582
- resume_download=resume_download,
583
- proxies=proxies,
584
- local_files_only=local_files_only,
585
- use_auth_token=use_auth_token,
586
- revision=revision,
587
- subfolder=subfolder,
588
- user_agent=user_agent,
589
- commit_hash=commit_hash,
590
- )
591
-
592
- if low_cpu_mem_usage:
593
- # Instantiate model with empty weights
594
- with accelerate.init_empty_weights():
595
- model = cls.from_config(config, **unused_kwargs)
596
-
597
- # if device_map is None, load the state dict and move the params from meta device to the cpu
598
- if device_map is None:
599
- param_device = "cpu"
600
- state_dict = load_state_dict(model_file, variant=variant)
601
- model._convert_deprecated_attention_blocks(state_dict)
602
- # move the params from meta device to cpu
603
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
604
- if len(missing_keys) > 0:
605
- raise ValueError(
606
- f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
607
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
608
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
609
- " those weights or else make sure your checkpoint file is correct."
610
- )
611
-
612
- empty_state_dict = model.state_dict()
613
- for param_name, param in state_dict.items():
614
- accepts_dtype = "dtype" in set(
615
- inspect.signature(set_module_tensor_to_device).parameters.keys()
616
- )
617
-
618
- if empty_state_dict[param_name].shape != param.shape:
619
- raise ValueError(
620
- f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
621
- )
622
-
623
- if accepts_dtype:
624
- set_module_tensor_to_device(
625
- model, param_name, param_device, value=param, dtype=torch_dtype
626
- )
627
- else:
628
- set_module_tensor_to_device(model, param_name, param_device, value=param)
629
- else: # else let accelerate handle loading and dispatching.
630
- # Load weights and dispatch according to the device_map
631
- # by default the device_map is None and the weights are loaded on the CPU
632
- accelerate.load_checkpoint_and_dispatch(
633
- model,
634
- model_file,
635
- device_map,
636
- max_memory=max_memory,
637
- offload_folder=offload_folder,
638
- offload_state_dict=offload_state_dict,
639
- dtype=torch_dtype,
640
- )
641
-
642
- loading_info = {
643
- "missing_keys": [],
644
- "unexpected_keys": [],
645
- "mismatched_keys": [],
646
- "error_msgs": [],
647
- }
648
- else:
649
- model = cls.from_config(config, **unused_kwargs)
650
-
651
- state_dict = load_state_dict(model_file, variant=variant)
652
- model._convert_deprecated_attention_blocks(state_dict)
653
-
654
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
655
- model,
656
- state_dict,
657
- model_file,
658
- pretrained_model_name_or_path,
659
- ignore_mismatched_sizes=ignore_mismatched_sizes,
660
- )
661
-
662
- loading_info = {
663
- "missing_keys": missing_keys,
664
- "unexpected_keys": unexpected_keys,
665
- "mismatched_keys": mismatched_keys,
666
- "error_msgs": error_msgs,
667
- }
668
-
669
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
670
- raise ValueError(
671
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
672
- )
673
- elif torch_dtype is not None:
674
- model = model.to(torch_dtype)
675
-
676
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
677
-
678
- # Set model in evaluation mode to deactivate DropOut modules by default
679
- model.eval()
680
- if output_loading_info:
681
- return model, loading_info
682
-
683
- return model
684
-
685
- @classmethod
686
- def _load_pretrained_model(
687
- cls,
688
- model,
689
- state_dict,
690
- resolved_archive_file,
691
- pretrained_model_name_or_path,
692
- ignore_mismatched_sizes=False,
693
- ):
694
- # Retrieve missing & unexpected_keys
695
- model_state_dict = model.state_dict()
696
- loaded_keys = list(state_dict.keys())
697
-
698
- expected_keys = list(model_state_dict.keys())
699
-
700
- original_loaded_keys = loaded_keys
701
-
702
- missing_keys = list(set(expected_keys) - set(loaded_keys))
703
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
704
-
705
- # Make sure we are able to load base models as well as derived models (with heads)
706
- model_to_load = model
707
-
708
- def _find_mismatched_keys(
709
- state_dict,
710
- model_state_dict,
711
- loaded_keys,
712
- ignore_mismatched_sizes,
713
- ):
714
- mismatched_keys = []
715
- if ignore_mismatched_sizes:
716
- for checkpoint_key in loaded_keys:
717
- model_key = checkpoint_key
718
-
719
- if (
720
- model_key in model_state_dict
721
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
722
- ):
723
- mismatched_keys.append(
724
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
725
- )
726
- del state_dict[checkpoint_key]
727
- return mismatched_keys
728
-
729
- if state_dict is not None:
730
- # Whole checkpoint
731
- mismatched_keys = _find_mismatched_keys(
732
- state_dict,
733
- model_state_dict,
734
- original_loaded_keys,
735
- ignore_mismatched_sizes,
736
- )
737
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
738
-
739
- if len(error_msgs) > 0:
740
- error_msg = "\n\t".join(error_msgs)
741
- if "size mismatch" in error_msg:
742
- error_msg += (
743
- "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
744
- )
745
- raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
746
-
747
- if len(unexpected_keys) > 0:
748
- logger.warning(
749
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
750
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
751
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
752
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
753
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
754
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
755
- " identical (initializing a BertForSequenceClassification model from a"
756
- " BertForSequenceClassification model)."
757
- )
758
- else:
759
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
760
- if len(missing_keys) > 0:
761
- logger.warning(
762
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
763
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
764
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
765
- )
766
- elif len(mismatched_keys) == 0:
767
- logger.info(
768
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
769
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
770
- f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
771
- " without further training."
772
- )
773
- if len(mismatched_keys) > 0:
774
- mismatched_warning = "\n".join(
775
- [
776
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
777
- for key, shape1, shape2 in mismatched_keys
778
- ]
779
- )
780
- logger.warning(
781
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
782
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
783
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
784
- " able to use it for predictions and inference."
785
- )
786
-
787
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
788
-
789
- @property
790
- def device(self) -> device:
791
- """
792
- `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
793
- device).
794
- """
795
- return get_parameter_device(self)
796
-
797
- @property
798
- def dtype(self) -> torch.dtype:
799
- """
800
- `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
801
- """
802
- return get_parameter_dtype(self)
803
-
804
- def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
805
- """
806
- Get number of (optionally, trainable or non-embeddings) parameters in the module.
807
-
808
- Args:
809
- only_trainable (`bool`, *optional*, defaults to `False`):
810
- Whether or not to return only the number of trainable parameters
811
-
812
- exclude_embeddings (`bool`, *optional*, defaults to `False`):
813
- Whether or not to return only the number of non-embeddings parameters
814
-
815
- Returns:
816
- `int`: The number of parameters.
817
- """
818
-
819
- if exclude_embeddings:
820
- embedding_param_names = [
821
- f"{name}.weight"
822
- for name, module_type in self.named_modules()
823
- if isinstance(module_type, torch.nn.Embedding)
824
- ]
825
- non_embedding_parameters = [
826
- parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
827
- ]
828
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
829
- else:
830
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
831
-
832
- def _convert_deprecated_attention_blocks(self, state_dict):
833
- deprecated_attention_block_paths = []
834
-
835
- def recursive_find_attn_block(name, module):
836
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
837
- deprecated_attention_block_paths.append(name)
838
-
839
- for sub_name, sub_module in module.named_children():
840
- sub_name = sub_name if name == "" else f"{name}.{sub_name}"
841
- recursive_find_attn_block(sub_name, sub_module)
842
-
843
- recursive_find_attn_block("", self)
844
-
845
- # NOTE: we have to check if the deprecated parameters are in the state dict
846
- # because it is possible we are loading from a state dict that was already
847
- # converted
848
-
849
- for path in deprecated_attention_block_paths:
850
- # group_norm path stays the same
851
-
852
- # query -> to_q
853
- if f"{path}.query.weight" in state_dict:
854
- state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
855
- if f"{path}.query.bias" in state_dict:
856
- state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
857
-
858
- # key -> to_k
859
- if f"{path}.key.weight" in state_dict:
860
- state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
861
- if f"{path}.key.bias" in state_dict:
862
- state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
863
-
864
- # value -> to_v
865
- if f"{path}.value.weight" in state_dict:
866
- state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
867
- if f"{path}.value.bias" in state_dict:
868
- state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
869
-
870
- # proj_attn -> to_out.0
871
- if f"{path}.proj_attn.weight" in state_dict:
872
- state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
873
- if f"{path}.proj_attn.bias" in state_dict:
874
- state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
 
 
models/pipelines.py CHANGED
@@ -1,12 +1,85 @@
1
  import torch
2
  from tqdm import tqdm
 
3
  import utils
4
- from utils import schedule
5
  from PIL import Image
6
  import gc
7
  import numpy as np
8
  from .attention import GatedSelfAttentionDense
9
  from .models import process_input_embeddings, torch_device
 
10
 
11
  @torch.no_grad()
12
  def encode(model_dict, image, generator):
@@ -53,6 +126,126 @@ def decode(vae, latents):
53
 
54
  return images
55
 
56
  @torch.no_grad()
57
  def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
58
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
@@ -132,9 +325,13 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
132
  frozen_steps=20, frozen_mask=None,
133
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
134
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
 
135
  return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
136
  """
137
  The `bboxes` should be a flat list rather than a list of lists (one box per phrase; phrases may be duplicated).
 
 
 
138
  """
139
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
140
 
@@ -161,6 +358,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
161
  if fast_after_steps is not None:
162
  scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
163
 
 
 
 
164
  if frozen_mask is not None:
165
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
166
 
@@ -171,6 +371,23 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
171
 
172
  boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
173
 
174
  if return_saved_cross_attn:
175
  saved_attns = []
176
 
@@ -196,6 +413,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
196
  if index == num_grounding_steps:
197
  gligen_enable_fuser(unet, False)
198
 
 
 
 
199
  # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
200
  latent_model_input = torch.cat([latents] * 2)
201
 
@@ -215,7 +435,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
215
  # perform guidance
216
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
217
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
218
-
219
  if dynamic_num_inference_steps:
220
  schedule.dynamically_adjust_inference_steps(scheduler, index, t)
221
 
@@ -225,12 +445,17 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
225
  if frozen_mask is not None and index < frozen_steps:
226
  latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
227
 
 
228
  if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
229
  if offload_latents_to_cpu:
230
  latents_all.append(latents.cpu())
231
  else:
232
  latents_all.append(latents)
233
 
 
 
 
 
234
  # Turn off fuser for typical SD
235
  gligen_enable_fuser(unet, False)
236
  images = decode(vae, latents)
@@ -247,3 +472,128 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
247
 
248
  return tuple(ret)
249
 
1
  import torch
2
  from tqdm import tqdm
3
+ from utils import guidance, schedule, boxdiff
4
  import utils
 
5
  from PIL import Image
6
  import gc
7
  import numpy as np
8
  from .attention import GatedSelfAttentionDense
9
  from .models import process_input_embeddings, torch_device
10
+ import warnings
11
+
12
+ # All keys: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
13
+ # Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
14
+ DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
15
+
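For reference, a hedged sketch of how a guidance key such as `("up", 1, 0, 0)` maps onto a cross-attention module, assuming diffusers' SD 1.x `UNet2DConditionModel` layout (`resolve_attn_key` is a hypothetical helper, not part of this repo):

```python
def resolve_attn_key(unet, key):
    block_type, block_idx, attn_idx, tf_idx = key
    if block_type == "mid":
        block = unet.mid_block  # single mid block; block_idx is a placeholder here
    elif block_type == "down":
        block = unet.down_blocks[block_idx]
    else:  # "up"
        block = unet.up_blocks[block_idx]
    # Each Transformer2DModel holds BasicTransformerBlocks; attn2 is the cross-attention.
    return block.attentions[attn_idx].transformer_blocks[tf_idx].attn2
```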
16
+ def latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, loss_scale = 30, loss_threshold = 0.2, max_iter = 5, max_index_step = 10, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, clear_cache=False, **kwargs):
17
+
18
+ iteration = 0
19
+
20
+ if index < max_index_step:
21
+ if isinstance(max_iter, list):
22
+ if len(max_iter) > index:
23
+ max_iter = max_iter[index]
24
+ else:
25
+ max_iter = max_iter[-1]
26
+
27
+ if verbose:
28
+ print(f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}")
29
+
30
+ while (loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step):
31
+ saved_attn = {}
32
+ full_cross_attention_kwargs = {
33
+ 'save_attn_to_dict': saved_attn,
34
+ 'save_keys': guidance_attn_keys,
35
+ }
36
+
37
+ if cross_attention_kwargs is not None:
38
+ full_cross_attention_kwargs.update(cross_attention_kwargs)
39
+
40
+ latents.requires_grad_(True)
41
+ latent_model_input = latents
42
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
43
+
44
+ unet(latent_model_input, t, encoder_hidden_states=cond_embeddings, return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)
45
+
46
+ # TODO: could return the attention maps for the required blocks only and not necessarily the final output
47
+ # update latents with guidance
48
+ loss = guidance.compute_ca_lossv3(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * loss_scale
49
+
50
+ if torch.isnan(loss):
51
+ print("**Loss is NaN**")
52
+
53
+ del full_cross_attention_kwargs, saved_attn
54
+ # calling gc.collect() here may release some memory
55
+
56
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
57
+
58
+ latents.requires_grad_(False)
59
+
60
+ if hasattr(scheduler, 'sigmas'):
61
+ latents = latents - grad_cond * scheduler.sigmas[index] ** 2
62
+ elif hasattr(scheduler, 'alphas_cumprod'):
63
+ warnings.warn("Using guidance scaled with alphas_cumprod")
64
+ # Scaling with classifier guidance
65
+ alpha_prod_t = scheduler.alphas_cumprod[t]
66
+ # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
67
+ # DDIM: https://arxiv.org/pdf/2010.02502.pdf
68
+ scale = (1 - alpha_prod_t) ** (0.5)
69
+ latents = latents - scale * grad_cond
70
+ else:
71
+ # NOTE: no scaling is performed
72
+ warnings.warn("No scaling in guidance is performed")
73
+ latents = latents - grad_cond
74
+ iteration += 1
75
+
76
+ if clear_cache:
77
+ utils.free_memory()
78
+
79
+ if verbose:
80
+ print(f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}")
81
+
82
+ return latents, loss
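In equation form, each inner iteration of `latent_backward_guidance` above performs one gradient step on the latents z with respect to the scaled cross-attention loss L (a summary of the three code paths, written in LaTeX):
z \leftarrow z - \sigma_i^{2} \, \nabla_{z}\mathcal{L}             % schedulers exposing `sigmas`
z \leftarrow z - \sqrt{1-\bar{\alpha}_t} \, \nabla_{z}\mathcal{L}  % schedulers exposing `alphas_cumprod` (classifier-guidance-style scaling)
z \leftarrow z - \nabla_{z}\mathcal{L}                             % otherwise, no scaling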
83
 
84
  @torch.no_grad()
85
  def encode(model_dict, image, generator):
 
126
 
127
  return images
128
 
129
+ def generate_semantic_guidance(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, object_positions, guidance_scale = 7.5, semantic_guidance_kwargs=None,
130
+ return_cross_attn=False, return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None, offload_guidance_cross_attn_to_cpu=False,
131
+ offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True, return_box_vis=False, show_progress=True, save_all_latents=False,
132
+ dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2, use_boxdiff=False):
133
+ """
134
+ object_positions: object indices in text tokens
135
+ return_cross_attn: should be deprecated. Use `return_saved_cross_attn` and the new format.
136
+ """
137
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
138
+ text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
139
+
140
+ # Just in case that we have in-place ops
141
+ latents = latents.clone()
142
+
143
+ if save_all_latents:
144
+ # offload to cpu to save space
145
+ if offload_latents_to_cpu:
146
+ latents_all = [latents.cpu()]
147
+ else:
148
+ latents_all = [latents]
149
+
150
+ scheduler.set_timesteps(num_inference_steps)
151
+ if fast_after_steps is not None:
152
+ scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
153
+
154
+ if dynamic_num_inference_steps:
155
+ original_num_inference_steps = scheduler.num_inference_steps
156
+
157
+ cross_attention_probs_down = []
158
+ cross_attention_probs_mid = []
159
+ cross_attention_probs_up = []
160
+
161
+ loss = torch.tensor(10000.)
162
+
163
+ # TODO: we can also save necessary tokens only to save memory.
164
+ # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
165
+ guidance_cross_attention_kwargs = {
166
+ 'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
167
+ 'enable_flash_attn': False
168
+ }
169
+
170
+ if return_saved_cross_attn:
171
+ saved_attns = []
172
+
173
+ main_cross_attention_kwargs = {
174
+ 'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
175
+ 'return_cond_ca_only': return_cond_ca_only,
176
+ 'return_token_ca_only': return_token_ca_only,
177
+ 'save_keys': saved_cross_attn_keys,
178
+ }
179
+
180
+ # Repeating keys leads to different weights for each key.
181
+ # assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}"
182
+
183
+ for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)):
184
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
185
+
186
+ if bboxes:
187
+ if use_boxdiff:
188
+ latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
189
+ else:
190
+ # If you encounter None in `guidance_attn_keys`, check whether `guidance_attn_keys` is included in `semantic_guidance_kwargs`; the default value has been removed.
191
+ latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
192
+
193
+ # predict the noise residual
194
+ with torch.no_grad():
195
+ latent_model_input = torch.cat([latents] * 2)
196
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
197
+
198
+ main_cross_attention_kwargs['save_attn_to_dict'] = {}
199
+
200
+ unet_output = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_cross_attention_probs=return_cross_attn, cross_attention_kwargs=main_cross_attention_kwargs)
201
+ noise_pred = unet_output.sample
202
+
203
+ if return_cross_attn:
204
+ cross_attention_probs_down.append(unet_output.cross_attention_probs_down)
205
+ cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid)
206
+ cross_attention_probs_up.append(unet_output.cross_attention_probs_up)
207
+
208
+ if return_saved_cross_attn:
209
+ saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
210
+
211
+ del main_cross_attention_kwargs['save_attn_to_dict']
212
+
213
+ # perform guidance
214
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
215
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
216
+
217
+ if dynamic_num_inference_steps:
218
+ schedule.dynamically_adjust_inference_steps(scheduler, index, t)
219
+
220
+ # compute the previous noisy sample x_t -> x_t-1
221
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
222
+
223
+ if save_all_latents:
224
+ if offload_latents_to_cpu:
225
+ latents_all.append(latents.cpu())
226
+ else:
227
+ latents_all.append(latents)
228
+
229
+ if dynamic_num_inference_steps:
230
+ # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
231
+ scheduler.num_inference_steps = original_num_inference_steps
232
+
233
+ images = decode(vae, latents)
234
+
235
+ ret = [latents, images]
236
+
237
+ if return_cross_attn:
238
+ ret.append((cross_attention_probs_down, cross_attention_probs_mid, cross_attention_probs_up))
239
+ if return_saved_cross_attn:
240
+ ret.append(saved_attns)
241
+ if return_box_vis:
242
+ pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
243
+ ret.append(pil_images)
244
+ if save_all_latents:
245
+ latents_all = torch.stack(latents_all, dim=0)
246
+ ret.append(latents_all)
247
+ return tuple(ret)
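A hypothetical call sketch for `generate_semantic_guidance` (not part of the diff; `model_dict`, `latents`, `input_embeddings`, `bboxes`, `phrases`, and `object_positions` are assumed to be prepared as elsewhere in this file, and the keyword values are illustrative only):
semantic_guidance_kwargs = dict(
    loss_scale=30, loss_threshold=0.2, max_iter=5, max_index_step=10,
    guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS, verbose=False,
)
# With the default flags, the returned tuple is (latents, images).
latents_out, images = generate_semantic_guidance(
    model_dict, latents, input_embeddings, num_inference_steps=50,
    bboxes=bboxes, phrases=phrases, object_positions=object_positions,
    guidance_scale=7.5, semantic_guidance_kwargs=semantic_guidance_kwargs,
)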
248
+
249
  @torch.no_grad()
250
  def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
251
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
 
325
  frozen_steps=20, frozen_mask=None,
326
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
327
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
328
+ semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
329
  return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
330
  """
331
  The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
332
+ batched:
333
+ Enabled: bboxes and phrases should be a list (batch dimension) of items (specify the bboxes/phrases of each image in the batch).
334
+ Disabled: bboxes and phrases should be a list of bboxes and phrases specifying the bboxes/phrases of one image (no batch dimension).
335
  """
336
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
337
 
 
358
  if fast_after_steps is not None:
359
  scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
360
 
361
+ if dynamic_num_inference_steps:
362
+ original_num_inference_steps = scheduler.num_inference_steps
363
+
364
  if frozen_mask is not None:
365
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
366
 
 
371
 
372
  boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
373
 
374
+ if semantic_guidance_bboxes and semantic_guidance:
375
+ loss = torch.tensor(10000.)
376
+ # TODO: we can also save necessary tokens only to save memory.
377
+ # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
378
+ guidance_cross_attention_kwargs = {
379
+ 'offload_cross_attn_to_cpu': False,
380
+ 'enable_flash_attn': False,
381
+ 'gligen': {
382
+ 'boxes': boxes[:condition_len // 2],
383
+ 'positive_embeddings': phrase_embeddings[:condition_len // 2],
384
+ 'masks': masks[:condition_len // 2],
385
+ 'fuser_attn_kwargs': {
386
+ 'enable_flash_attn': False,
387
+ }
388
+ }
389
+ }
390
+
391
  if return_saved_cross_attn:
392
  saved_attns = []
393
 
 
413
  if index == num_grounding_steps:
414
  gligen_enable_fuser(unet, False)
415
 
416
+ if semantic_guidance_bboxes and semantic_guidance:
417
+ with torch.enable_grad():
418
+ latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
419
  # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
420
  latent_model_input = torch.cat([latents] * 2)
421
 
 
435
  # perform guidance
436
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
437
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
438
+
439
  if dynamic_num_inference_steps:
440
  schedule.dynamically_adjust_inference_steps(scheduler, index, t)
441
 
 
445
  if frozen_mask is not None and index < frozen_steps:
446
  latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
447
 
448
+ # Do not save the latents in the fast steps
449
  if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
450
  if offload_latents_to_cpu:
451
  latents_all.append(latents.cpu())
452
  else:
453
  latents_all.append(latents)
454
 
455
+ if dynamic_num_inference_steps:
456
+ # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
457
+ scheduler.num_inference_steps = original_num_inference_steps
458
+
459
  # Turn off fuser for typical SD
460
  gligen_enable_fuser(unet, False)
461
  images = decode(vae, latents)
 
472
 
473
  return tuple(ret)
474
 
475
+
476
+ def get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength):
477
+ # get the original timestep using init_timestep
478
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
479
+
480
+ t_start = max(num_inference_steps - init_timestep, 0)
481
+
482
+ # safety for t_start overflow to prevent an empty timesteps slice
483
+ if t_start == 0:
484
+ return inverse_scheduler.timesteps, num_inference_steps
485
+ timesteps = inverse_scheduler.timesteps[:-t_start]
486
+
487
+ return timesteps, num_inference_steps - t_start
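A quick worked example of the slicing above, assuming the inverse scheduler has been set up with num_inference_steps = 50:
# strength = 1.0 -> init_timestep = 50, t_start = 0  -> all 50 timesteps (early-return branch)
# strength = 0.6 -> init_timestep = 30, t_start = 20 -> timesteps[:-20], i.e. 30 remaining steps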
488
+
489
+ @torch.no_grad()
490
+ def invert(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5):
491
+ """
492
+ latents: encoded from the image, should not have noise (t = 0)
493
+
494
+ returns inverted_latents for all time steps
495
+ """
496
+ vae, tokenizer, text_encoder, unet, scheduler, inverse_scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.inverse_scheduler, model_dict.dtype
497
+ text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
498
+
499
+ inverse_scheduler.set_timesteps(num_inference_steps, device=latents.device)
500
+ # We need to invert all steps because we need them to generate the background.
501
+ timesteps, num_inference_steps = get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength=1.0)
502
+
503
+ inverted_latents = [latents.cpu()]
504
+ for t in tqdm(timesteps[:-1]):
505
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
506
+ if guidance_scale > 0.:
507
+ latent_model_input = torch.cat([latents] * 2)
508
+
509
+ latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
510
+
511
+ # predict the noise residual
512
+ with torch.no_grad():
513
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
514
+
515
+ # perform guidance
516
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
517
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
518
+ else:
519
+ latent_model_input = latents
520
+
521
+ latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
522
+
523
+ # predict the noise residual
524
+ with torch.no_grad():
525
+ noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample
526
+
527
+ # perform guidance
528
+ noise_pred = noise_pred_uncond
529
+
530
+ # compute the previous noisy sample x_t -> x_t-1
531
+ latents = inverse_scheduler.step(noise_pred, t, latents).prev_sample
532
+
533
+ inverted_latents.append(latents.cpu())
534
+
535
+ assert len(inverted_latents) == len(timesteps)
536
+ # timestep is the first dimension
537
+ inverted_latents = torch.stack(list(reversed(inverted_latents)), dim=0)
538
+
539
+ return inverted_latents
540
+
541
+ def generate_partial_frozen(model_dict, latents_all, frozen_mask, input_embeddings, num_inference_steps, frozen_steps, guidance_scale = 7.5, bboxes=None, phrases=None, object_positions=None, semantic_guidance_kwargs=None, offload_guidance_cross_attn_to_cpu=False, use_boxdiff=False):
542
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
543
+ text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
544
+
545
+ scheduler.set_timesteps(num_inference_steps)
546
+ frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
547
+
548
+ latents = latents_all[0]
549
+
550
+ if bboxes:
551
+ # With semantic guidance
552
+ loss = torch.tensor(10000.)
553
+
554
+ # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
555
+ guidance_cross_attention_kwargs = {
556
+ 'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
557
+ # Getting invalid argument on backward, probably due to insufficient shared memory
558
+ 'enable_flash_attn': False
559
+ }
560
+
561
+ for index, t in enumerate(tqdm(scheduler.timesteps)):
562
+ if bboxes:
563
+ # With semantic guidance, `guidance_attn_keys` should be in `semantic_guidance_kwargs`
564
+ if use_boxdiff:
565
+ latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
566
+ else:
567
+ latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
568
+
569
+ with torch.no_grad():
570
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
571
+ latent_model_input = torch.cat([latents] * 2)
572
+
573
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
574
+
575
+ # predict the noise residual
576
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
577
+
578
+ # perform guidance
579
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
580
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
581
+
582
+ # compute the previous noisy sample x_t -> x_t-1
583
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
584
+
585
+ if index < frozen_steps:
586
+ latents = latents_all[index+1] * frozen_mask + latents * (1. - frozen_mask)
587
+
588
+ # scale and decode the image latents with vae
589
+ scaled_latents = 1 / 0.18215 * latents
590
+ with torch.no_grad():
591
+ image = vae.decode(scaled_latents).sample
592
+
593
+ image = (image / 2 + 0.5).clamp(0, 1)
594
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
595
+ images = (image * 255).round().astype("uint8")
596
+
597
+ ret = [latents, images]
598
+
599
+ return tuple(ret)
models/sam.py CHANGED
@@ -164,8 +164,10 @@ def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H
164
  return mask_selected, conf_score_selected
165
 
166
  def sam_refine_box(sam_input_image, box, *args, **kwargs):
167
- sam_input_images, boxes = [sam_input_image], [box]
168
- return sam_refine_boxes(sam_input_images, boxes, *args, **kwargs)
 
 
169
 
170
  def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
171
  # (w, h)
 
164
  return mask_selected, conf_score_selected
165
 
166
  def sam_refine_box(sam_input_image, box, *args, **kwargs):
167
+ # One image with one box
168
+ sam_input_images, boxes = [sam_input_image], [[box]]
169
+ mask_selected_batched_list, conf_score_selected_batched_list = sam_refine_boxes(sam_input_images, boxes, *args, **kwargs)
170
+ return mask_selected_batched_list[0][0], conf_score_selected_batched_list[0][0]
171
 
172
  def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
173
  # (w, h)
utils/attn.py ADDED
@@ -0,0 +1,140 @@
1
+ # visualization-related functions are in vis
2
+ import numbers
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ import utils
8
+
9
+ def get_token_attnv2(token_id, saved_attns, attn_key, visualize_step_start=10, input_ca_has_condition_only=False, return_np=False):
10
+ """
11
+ saved_attns: a list of saved_attn (list is across timesteps)
12
+
13
+ moves to cpu by default
14
+ """
15
+ saved_attns = saved_attns[visualize_step_start:]
16
+
17
+ saved_attns = [saved_attn[attn_key].cpu() for saved_attn in saved_attns]
18
+
19
+ attn = torch.stack(saved_attns, dim=0).mean(dim=0)
20
+
21
+ # print("attn shape", attn.shape)
22
+
23
+ # attn: (batch, head, spatial, text)
24
+
25
+ if not input_ca_has_condition_only:
26
+ assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
27
+ attn = attn[1]
28
+ else:
29
+ assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
30
+ attn = attn[0]
31
+ attn = attn.mean(dim=0)[:, token_id]
32
+ H = W = int(math.sqrt(attn.shape[0]))
33
+ attn = attn.reshape((H, W))
34
+
35
+ if return_np:
36
+ return attn.numpy()
37
+
38
+ return attn
39
+
40
+ def shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, horizontal_shift_only=False):
41
+ """
42
+ `horizontal_shift_only`: only shift horizontally. If you use `offset` from `compose_latents_with_alignment` with `horizontal_shift_only=True`, the `offset` already has y_offset = 0 and this option is not needed.
43
+ """
44
+ x_offset, y_offset = offset
45
+ if horizontal_shift_only:
46
+ y_offset = 0.
47
+
48
+ new_saved_attns_item = {}
49
+ for k in guidance_attn_keys:
50
+ attn_map = saved_attns_item[k]
51
+
52
+ attn_size = attn_map.shape[-2]
53
+ attn_h = attn_w = int(math.sqrt(attn_size))
54
+ # Example dimensions: [batch_size, num_heads, 8, 8, num_tokens]
55
+ attn_map = attn_map.unflatten(2, (attn_h, attn_w))
56
+ attn_map = utils.shift_tensor(
57
+ attn_map, x_offset, y_offset,
58
+ offset_normalized=True, ignore_last_dim=True
59
+ )
60
+ attn_map = attn_map.flatten(2, 3)
61
+
62
+ new_saved_attns_item[k] = attn_map
63
+
64
+ return new_saved_attns_item
65
+
66
+ def shift_saved_attns(saved_attns, offset, guidance_attn_keys, **kwargs):
67
+ # Iterate over timesteps
68
+ shifted_saved_attns = [shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, **kwargs) for saved_attns_item in saved_attns]
69
+
70
+ return shifted_saved_attns
71
+
72
+
73
+ class GaussianSmoothing(nn.Module):
74
+ """
75
+ Apply gaussian smoothing on a
76
+ 1d, 2d or 3d tensor. Filtering is performed separately for each channel
77
+ in the input using a depthwise convolution.
78
+ Arguments:
79
+ channels (int, sequence): Number of channels of the input tensors. Output will
80
+ have this number of channels as well.
81
+ kernel_size (int, sequence): Size of the gaussian kernel.
82
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
83
+ dim (int, optional): The number of dimensions of the data.
84
+ Default value is 2 (spatial).
85
+
86
+ Credit: https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10
87
+ """
88
+
89
+ def __init__(self, channels, kernel_size, sigma, dim=2):
90
+ super(GaussianSmoothing, self).__init__()
91
+ if isinstance(kernel_size, numbers.Number):
92
+ kernel_size = [kernel_size] * dim
93
+ if isinstance(sigma, numbers.Number):
94
+ sigma = [sigma] * dim
95
+
96
+ # The gaussian kernel is the product of the
97
+ # gaussian function of each dimension.
98
+ kernel = 1
99
+ meshgrids = torch.meshgrid(
100
+ [
101
+ torch.arange(size, dtype=torch.float32)
102
+ for size in kernel_size
103
+ ]
104
+ )
105
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
106
+ mean = (size - 1) / 2
107
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
108
+ torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
109
+
110
+ # Make sure sum of values in gaussian kernel equals 1.
111
+ kernel = kernel / torch.sum(kernel)
112
+
113
+ # Reshape to depthwise convolutional weight
114
+ kernel = kernel.view(1, 1, *kernel.size())
115
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
116
+
117
+ self.register_buffer('weight', kernel)
118
+ self.groups = channels
119
+
120
+ if dim == 1:
121
+ self.conv = F.conv1d
122
+ elif dim == 2:
123
+ self.conv = F.conv2d
124
+ elif dim == 3:
125
+ self.conv = F.conv3d
126
+ else:
127
+ raise RuntimeError(
128
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(
129
+ dim)
130
+ )
131
+
132
+ def forward(self, input):
133
+ """
134
+ Apply gaussian filter to input.
135
+ Arguments:
136
+ input (torch.Tensor): Input to apply gaussian filter on.
137
+ Returns:
138
+ filtered (torch.Tensor): Filtered output.
139
+ """
140
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
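A minimal usage sketch (not part of the diff), mirroring how `utils/boxdiff.py` below applies this module to a single-channel 2D attention map; the reflect padding keeps the spatial size for kernel_size=3:
import torch
import torch.nn.functional as F

attn = torch.rand(64, 64)                                       # a 2D attention map
smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
padded = F.pad(attn[None, None], (1, 1, 1, 1), mode='reflect')  # (1, 1, 66, 66)
smoothed = smoothing(padded).squeeze(0).squeeze(0)              # back to (64, 64)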
utils/boxdiff.py ADDED
@@ -0,0 +1,259 @@
1
+ """
2
+ This is a reimplementation of the BoxDiff baseline for reference and comparison. It is not used in the Web UI and is not enabled by default, since the current attention guidance implementation (in `guidance`), which uses attention maps from multiple levels and attention transfer, appears to be more robust and coherent.
3
+
4
+ Credit: https://github.com/showlab/BoxDiff/blob/master/pipeline/sd_pipeline_boxdiff.py
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import math
10
+ import warnings
11
+ import gc
12
+ from collections.abc import Iterable
13
+ import utils
14
+ from . import guidance
15
+ from .attn import GaussianSmoothing
16
+
17
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
18
+
19
+
20
+ def _compute_max_attention_per_index(attention_maps: torch.Tensor,
21
+ object_positions: List[List[int]],
22
+ smooth_attentions: bool = False,
23
+ sigma: float = 0.5,
24
+ kernel_size: int = 3,
25
+ normalize_eot: bool = False,
26
+ bboxes: List[List[int]] = None,
27
+ P: float = 0.2,
28
+ L: int = 1,
29
+ ) -> List[torch.Tensor]:
30
+ """ Computes the maximum attention value for each of the tokens we wish to alter. """
31
+ last_idx = -1
32
+ assert not normalize_eot, "normalize_eot is unimplemented"
33
+
34
+ attention_for_text = attention_maps[:, :, 1:last_idx]
35
+ attention_for_text *= 100
36
+ attention_for_text = F.softmax(attention_for_text, dim=-1)
37
+
38
+ # Extract the maximum values
39
+ max_indices_list_fg = []
40
+ max_indices_list_bg = []
41
+ dist_x = []
42
+ dist_y = []
43
+
44
+ for obj_idx, text_positions_per_obj in enumerate(object_positions):
45
+ for text_position_per_obj in text_positions_per_obj:
46
+ # Shift indices since we removed the first token
47
+ image = attention_for_text[:, :, text_position_per_obj - 1]
48
+ H, W = image.shape
49
+
50
+ obj_mask = torch.zeros_like(image)
51
+ corner_mask_x = torch.zeros(
52
+ (W,), device=obj_mask.device, dtype=obj_mask.dtype)
53
+ corner_mask_y = torch.zeros(
54
+ (H,), device=obj_mask.device, dtype=obj_mask.dtype)
55
+
56
+ obj_boxes = bboxes[obj_idx]
57
+
58
+ # We support two levels (one box per phrase) and three levels (multiple boxes per phrase)
59
+ if not isinstance(obj_boxes[0], Iterable):
60
+ obj_boxes = [obj_boxes]
61
+
62
+ for obj_box in obj_boxes:
63
+ x_min, y_min, x_max, y_max = utils.scale_proportion(
64
+ obj_box, H=H, W=W)
65
+ obj_mask[y_min: y_max, x_min: x_max] = 1
66
+
67
+ corner_mask_x[max(x_min - L, 0): min(x_min + L + 1, W)] = 1.
68
+ corner_mask_x[max(x_max - L, 0): min(x_max + L + 1, W)] = 1.
69
+ corner_mask_y[max(y_min - L, 0): min(y_min + L + 1, H)] = 1.
70
+ corner_mask_y[max(y_max - L, 0): min(y_max + L + 1, H)] = 1.
71
+
72
+ bg_mask = 1 - obj_mask
73
+
74
+ if smooth_attentions:
75
+ smoothing = GaussianSmoothing(
76
+ channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
77
+ input = F.pad(image.unsqueeze(0).unsqueeze(0),
78
+ (1, 1, 1, 1), mode='reflect')
79
+ image = smoothing(input).squeeze(0).squeeze(0)
80
+
81
+ # Inner-Box constraint
82
+ k = (obj_mask.sum() * P).long()
83
+ max_indices_list_fg.append(
84
+ (image * obj_mask).reshape(-1).topk(k)[0].mean())
85
+
86
+ # Outer-Box constraint
87
+ k = (bg_mask.sum() * P).long()
88
+ max_indices_list_bg.append(
89
+ (image * bg_mask).reshape(-1).topk(k)[0].mean())
90
+
91
+ # Corner Constraint
92
+ gt_proj_x = torch.max(obj_mask, dim=0).values
93
+ gt_proj_y = torch.max(obj_mask, dim=1).values
94
+
95
+ # create gt according to the number L
96
+ dist_x.append((F.l1_loss(image.max(dim=0)[
97
+ 0], gt_proj_x, reduction='none') * corner_mask_x).mean())
98
+ dist_y.append((F.l1_loss(image.max(dim=1)[
99
+ 0], gt_proj_y, reduction='none') * corner_mask_y).mean())
100
+
101
+ return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y
102
+
103
+
104
+ def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor],
105
+ dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
106
+ """ Computes the attend-and-excite loss using the maximum attention value for each token. """
107
+ losses_fg = [max(0, 1. - curr_max)
108
+ for curr_max in max_attention_per_index_fg]
109
+ losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg]
110
+ loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)
111
+
112
+ # print(f"{losses_fg}, {losses_bg}, {dist_x}, {dist_y}, {loss}")
113
+
114
+ if return_losses:
115
+ return max(losses_fg), losses_fg
116
+ else:
117
+ return max(losses_fg), loss
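Written out, the objective assembled by `_compute_loss` combines the Inner-Box, Outer-Box, and Corner constraints; with A_o the smoothed attention map for object o, m_o its box mask, topk(.) the mean of the top P*|mask| values, and d_x, d_y the corner-masked L1 distances between the attention's axis-wise max projections and those of the box:
\mathcal{L} = \sum_o \max\big(0,\, 1 - \mathrm{topk}(A_o \odot m_o)\big)
            + \sum_o \max\big(0,\, \mathrm{topk}(A_o \odot (1 - m_o))\big)
            + \sum_o \big(d_x^{(o)} + d_y^{(o)}\big)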
118
+
119
+
120
+ def compute_ca_loss_boxdiff(saved_attn, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns=None, ref_ca_last_token_only=True, ref_ca_word_token_only=False, word_token_indices=None, index=None, ref_ca_loss_weight=1.0, verbose=False, **kwargs):
121
+ """
122
+ v3 is equivalent to v2 but with new dictionary format for attention maps.
123
+ The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing this loss.
124
+ `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
125
+
126
+ `index` is the timestep.
127
+ `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
128
+ `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
129
+ """
130
+ loss = torch.tensor(0).float().cuda()
131
+ object_number = len(bboxes)
132
+ if object_number == 0:
133
+ return loss
134
+
135
+ attn_map_list = []
136
+
137
+ for attn_key in guidance_attn_keys:
138
+ # We only have 1 cross attention for mid.
139
+ attn_map_integrated = saved_attn[attn_key]
140
+ if not attn_map_integrated.is_cuda:
141
+ attn_map_integrated = attn_map_integrated.cuda()
142
+ # Example dimension: [20, 64, 77]
143
+ attn_map = attn_map_integrated.squeeze(dim=0)
144
+ attn_map_list.append(attn_map)
145
+ # This averages both across layers and across attention heads
146
+ attn_map = torch.cat(attn_map_list, dim=0).mean(dim=0)
147
+ loss = add_ca_loss_per_attn_map_to_loss_boxdiff(
148
+ loss, attn_map, object_number, bboxes, object_positions, verbose=verbose, **kwargs)
149
+
150
+ if ref_ca_saved_attns is not None:
151
+ warnings.warn('Attention reference loss is enabled in boxdiff mode. The original boxdiff does not have attention reference loss.')
152
+
153
+ ref_loss = torch.tensor(0).float().cuda()
154
+ ref_loss = guidance.add_ref_ca_loss_per_attn_map_to_lossv2(
155
+ ref_loss, saved_attn=saved_attn, object_number=object_number, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
156
+ ref_ca_saved_attns=ref_ca_saved_attns, ref_ca_last_token_only=ref_ca_last_token_only, ref_ca_word_token_only=ref_ca_word_token_only, word_token_indices=word_token_indices, verbose=verbose, index=index, loss_weight=ref_ca_loss_weight
157
+ )
158
+ print(f"loss {loss.item():.3f}, reference attention loss (weighted) {ref_loss.item():.3f}")
159
+ loss += ref_loss
160
+
161
+ return loss
162
+
163
+
164
+ def add_ca_loss_per_attn_map_to_loss_boxdiff(original_loss, attention_maps, object_number, bboxes, object_positions, P=0.2, L=1, smooth_attentions=True, sigma=0.5, kernel_size=3, normalize_eot=False, verbose=False):
165
+ # NOTE: normalize_eot is enabled in SD v2.1 in boxdiff
166
+ i, j = attention_maps.shape
167
+ H = W = int(math.sqrt(i))
168
+
169
+ attention_maps = attention_maps.view(H, W, j)
170
+ # attention_maps is aggregated cross attn map across layers and steps
171
+ # attention_maps shape: [H, W, 77]
172
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = _compute_max_attention_per_index(
173
+ attention_maps=attention_maps,
174
+ object_positions=object_positions,
175
+ smooth_attentions=smooth_attentions,
176
+ sigma=sigma,
177
+ kernel_size=kernel_size,
178
+ normalize_eot=normalize_eot,
179
+ bboxes=bboxes,
180
+ P=P,
181
+ L=L
182
+ )
183
+
184
+ _, loss = _compute_loss(max_attention_per_index_fg,
185
+ max_attention_per_index_bg, dist_x, dist_y)
186
+
187
+ return original_loss + loss
188
+
189
+
190
+ def latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, amp_loss_scale=10, latent_scale=20, scale_range=(1., 0.5), max_index_step=25, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, **kwargs):
191
+ """
192
+ amp_loss_scale: this scales the loss but is de-scaled before the update is applied to the latents. This is to prevent overflow/underflow with amp, not to adjust the update step size.
193
+ latent_scale: this scales the step size for update (scale_factor in boxdiff).
194
+ """
195
+
196
+ if index < max_index_step:
197
+ saved_attn = {}
198
+ full_cross_attention_kwargs = {
199
+ 'save_attn_to_dict': saved_attn,
200
+ 'save_keys': guidance_attn_keys,
201
+ }
202
+
203
+ if cross_attention_kwargs is not None:
204
+ full_cross_attention_kwargs.update(cross_attention_kwargs)
205
+
206
+ latents.requires_grad_(True)
207
+ latent_model_input = latents
208
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
209
+
210
+ unet(latent_model_input, t, encoder_hidden_states=cond_embeddings,
211
+ return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)
212
+
213
+ # TODO: could return the attention maps for the required blocks only and not necessarily the final output
214
+ # update latents with guidance
215
+ loss = compute_ca_loss_boxdiff(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
216
+ ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * amp_loss_scale
217
+
218
+ if torch.isnan(loss):
219
+ print("**Loss is NaN**")
220
+
221
+ del full_cross_attention_kwargs, saved_attn
222
+ # Calling gc.collect() here may release some memory
223
+
224
+ grad_cond = torch.autograd.grad(
225
+ loss.requires_grad_(True), [latents])[0]
226
+
227
+ latents.requires_grad_(False)
228
+
229
+ if True:
230
+ warnings.warn("Using guidance scaled with sqrt scale")
231
+ # According to boxdiff's implementation: https://github.com/Sierkinhane/BoxDiff/blob/16ffb677a9128128e04553a0200870a526731be0/pipeline/sd_pipeline_boxdiff.py#L616
232
+ scale = (scale_range[0] + (scale_range[1] - scale_range[0])
233
+ * index / (len(scheduler.timesteps) - 1)) ** (0.5)
234
+ latents = latents - latent_scale * scale / amp_loss_scale * grad_cond
235
+ elif hasattr(scheduler, 'sigmas'):
236
+ warnings.warn("Using guidance scaled with sigmas")
237
+ scale = scheduler.sigmas[index] ** 2
238
+ latents = latents - grad_cond * scale
239
+ elif hasattr(scheduler, 'alphas_cumprod'):
240
+ warnings.warn("Using guidance scaled with alphas_cumprod")
241
+ # Scaling with classifier guidance
242
+ alpha_prod_t = scheduler.alphas_cumprod[t]
243
+ # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
244
+ # DDIM: https://arxiv.org/pdf/2010.02502.pdf
245
+ scale = (1 - alpha_prod_t) ** (0.5)
246
+ latents = latents - latent_scale * scale / amp_loss_scale * grad_cond
247
+ else:
248
+ warnings.warn("No scaling in guidance is performed")
249
+ scale = 1
250
+ latents = latents - grad_cond
251
+
252
+ gc.collect()
253
+ torch.cuda.empty_cache()
254
+
255
+ if verbose:
256
+ print(
257
+ f"time index {index}, loss: {loss.item() / amp_loss_scale:.3f} (de-scaled with scale {amp_loss_scale:.1f}), latent grad scale: {scale:.3f}")
258
+
259
+ return latents, loss
utils/guidance.py ADDED
@@ -0,0 +1,358 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from collections.abc import Iterable
5
+ import warnings
6
+
7
+ import utils
8
+
9
+ # Returns a list mapping token index to token string (the prompt as a list of token strings)
10
+ def get_token_map(tokenizer, prompt, verbose=False, padding="do_not_pad"):
11
+ fg_prompt_tokens = tokenizer([prompt], padding=padding, max_length=77, return_tensors="np")
12
+ input_ids = fg_prompt_tokens['input_ids'][0]
13
+
14
+ # index_to_last_with = np.max(np.where(input_ids == 593))
15
+ # index_to_last_eot = np.max(np.where(input_ids == 49407))
16
+
17
+ token_map = []
18
+ for ind, item in enumerate(input_ids.tolist()):
19
+
20
+ token = tokenizer._convert_id_to_token(item)
21
+ if verbose:
22
+ print(f"{ind}, {token} ({item})")
23
+
24
+ token_map.append(token)
25
+
26
+ # If we don't pad, we don't need to break.
27
+ # if item == tokenizer.eos_token_id:
28
+ # break
29
+
30
+ return token_map
31
+
32
+ def get_phrase_indices(tokenizer, prompt, phrases, verbose=False, words=None, include_eos=False, token_map=None, return_word_token_indices=False, add_suffix_if_not_found=False):
33
+ for obj in phrases:
34
+ # Suffix the prompt with the object name for attention guidance if the object is not in the prompt, using "|" to separate the prompt and the suffix
35
+ if obj not in prompt:
36
+ prompt += "| " + obj
37
+
38
+ if token_map is None:
39
+ # We allow using a pre-computed token map.
40
+ token_map = get_token_map(tokenizer, prompt=prompt, verbose=verbose, padding="do_not_pad")
41
+ token_map_str = " ".join(token_map)
42
+
43
+ object_positions = []
44
+ word_token_indices = []
45
+ for obj_ind, obj in enumerate(phrases):
46
+ phrase_token_map = get_token_map(tokenizer, prompt=obj, verbose=verbose, padding="do_not_pad")
47
+ # Remove <bos> and <eos> in substr
48
+ phrase_token_map = phrase_token_map[1:-1]
49
+ phrase_token_map_len = len(phrase_token_map)
50
+ phrase_token_map_str = " ".join(phrase_token_map)
51
+
52
+ if verbose:
53
+ print("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
54
+
55
+ # Count the number of tokens before the substring
56
+ # The substring comes with a trailing space that needs to be removed, hence the minus one in the index.
57
+ obj_first_index = len(token_map_str[:token_map_str.index(phrase_token_map_str)-1].split(" "))
58
+
59
+ obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len))
60
+ if include_eos:
61
+ obj_position.append(token_map.index(tokenizer.eos_token))
62
+ object_positions.append(obj_position)
63
+
64
+ if return_word_token_indices:
65
+ # Picking the last token in the specification
66
+ if words is None:
67
+ so_token_index = object_positions[0][-1]
68
+ # Picking the noun, or pooling attention over the tokens, may be better
69
+ print(f"Picking the last token \"{token_map[so_token_index]}\" ({so_token_index}) as attention token for extracting attention for SAM, which might not be the right one")
70
+ else:
71
+ word = words[obj_ind]
72
+ word_token_map = get_token_map(tokenizer, prompt=word, verbose=verbose, padding="do_not_pad")
73
+ # Get the index of the last token of word (the occurrence in phrase) in the prompt. Note that we skip the <eos> token through indexing with -2.
74
+ so_token_index = obj_first_index + phrase_token_map.index(word_token_map[-2])
75
+
76
+ if verbose:
77
+ print("so_token_index:", so_token_index)
78
+
79
+ word_token_indices.append(so_token_index)
80
+
81
+ if return_word_token_indices:
82
+ if add_suffix_if_not_found:
83
+ return object_positions, word_token_indices, prompt
84
+ return object_positions, word_token_indices
85
+
86
+ if add_suffix_if_not_found:
87
+ return object_positions, prompt
88
+
89
+ return object_positions
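A hypothetical call sketch (not part of the diff; the exact indices depend on the CLIP tokenizer, so only the shape of the result is described):
object_positions, updated_prompt = get_phrase_indices(
    tokenizer,
    prompt="A realistic photo of a gray cat and an orange dog on the grass",
    phrases=["a gray cat", "an orange dog"],
    add_suffix_if_not_found=True,
)
# object_positions: one list of token indices per phrase, pointing into the (possibly suffixed) prompt.
# updated_prompt is returned because phrases missing from the prompt are appended as "| <phrase>".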
90
+
91
+ def add_ca_loss_per_attn_map_to_loss(loss, attn_map, object_number, bboxes, object_positions, use_ratio_based_loss=True, fg_top_p=0.2, bg_top_p=0.2, fg_weight=1.0, bg_weight=1.0, verbose=False):
92
+ """
93
+ fg_top_p, bg_top_p, fg_weight, and bg_weight are only used with max-based loss
94
+ """
95
+
96
+ # Uncomment to debug:
97
+ # print(fg_top_p, bg_top_p, fg_weight, bg_weight)
98
+
99
+ # b is the number of heads, not batch
100
+ b, i, j = attn_map.shape
101
+ H = W = int(math.sqrt(i))
102
+ for obj_idx in range(object_number):
103
+ obj_loss = 0
104
+ mask = torch.zeros(size=(H, W), device="cuda")
105
+ obj_boxes = bboxes[obj_idx]
106
+
107
+ # We support two levels (one box per phrase) and three levels (multiple boxes per phrase)
108
+ if not isinstance(obj_boxes[0], Iterable):
109
+ obj_boxes = [obj_boxes]
110
+
111
+ for obj_box in obj_boxes:
112
+ # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
113
+ x_min, y_min, x_max, y_max = utils.scale_proportion(obj_box, H=H, W=W)
114
+ mask[y_min: y_max, x_min: x_max] = 1
115
+
116
+ for obj_position in object_positions[obj_idx]:
117
+ # Could potentially optimize to compute this for loop in batch.
118
+ # Could crop the ref cross attention before saving to save memory.
119
+
120
+ ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
121
+
122
+ if use_ratio_based_loss:
123
+ warnings.warn("Using ratio-based loss, which is deprecated. Max-based loss is recommended. The scale may be different.")
124
+ # Original loss function (ratio-based loss function)
125
+
126
+ # Enforces the attention to be within the mask only. Does not enforce within-mask distribution.
127
+ activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1)
128
+ obj_loss += torch.mean((1 - activation_value) ** 2)
129
+ # if verbose:
130
+ # print(f"enforce attn to be within the mask loss: {torch.mean((1 - activation_value) ** 2).item():.2f}")
131
+ else:
132
+ # Max-based loss function
133
+
134
+ # shape: (b, H * W)
135
+ ca_map_obj = attn_map[:, :, obj_position] # .reshape(b, H, W)
136
+ k_fg = (mask.sum() * fg_top_p).long().clamp_(min=1)
137
+ k_bg = ((1 - mask).sum() * bg_top_p).long().clamp_(min=1)
138
+
139
+ mask_1d = mask.view(1, -1)
140
+
141
+ # Take the topk over spatial dimension, and then take the sum over heads dim
142
+ # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own.
143
+ obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight
144
+ obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight
145
+
146
+ loss += obj_loss / len(object_positions[obj_idx])
147
+
148
+ return loss
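In equation form, the max-based branch above adds, for each object o and each of its token positions, with A_h the per-head cross-attention over spatial locations, m the box mask, k_fg roughly fg_top_p*|m|, k_bg roughly bg_top_p*|1-m|, and topk(.) the mean of the top-k values:
\mathcal{L}_o \mathrel{+}= w_{fg} \sum_h \big(1 - \mathrm{topk}_{k_{fg}}(A_h \odot m)\big)
             + w_{bg} \sum_h \mathrm{topk}_{k_{bg}}\big(A_h \odot (1 - m)\big)
and the total loss accumulates each L_o divided by the number of token positions of object o.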
149
+
150
+ def add_ref_ca_loss_per_attn_map_to_lossv2(loss, saved_attn, object_number, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns, ref_ca_last_token_only, ref_ca_word_token_only, word_token_indices, index, loss_weight, eps=1e-5, verbose=False):
151
+ """
152
+ This adds the ca loss with a reference. Note that this should be used together with the ca loss without a reference, since it only penalizes the difference between the normalized ca of the reference and the target.
153
+
154
+ `ref_ca_saved_attn` should have the same structure as bboxes and object_positions (until the inner content, which should be similar to saved_attn).
155
+ """
156
+
157
+ if loss_weight == 0.:
158
+ # Skip computing the reference loss if the loss weight is 0.
159
+ return loss
160
+
161
+ for obj_idx in range(object_number):
162
+ obj_loss = 0
163
+
164
+ obj_boxes = bboxes[obj_idx]
165
+ obj_ref_ca_saved_attns = ref_ca_saved_attns[obj_idx]
166
+
167
+ # We support two levels (one box per phrase) and three levels (multiple boxes per phrase)
168
+ if not isinstance(obj_boxes[0], Iterable):
169
+ obj_boxes = [obj_boxes]
170
+ obj_ref_ca_saved_attns = [obj_ref_ca_saved_attns]
171
+
172
+ assert len(obj_boxes) == len(obj_ref_ca_saved_attns), f"obj_boxes: {len(obj_boxes)}, obj_ref_ca_saved_attns: {len(obj_ref_ca_saved_attns)}"
173
+
174
+ for obj_box, obj_ref_ca_saved_attn in zip(obj_boxes, obj_ref_ca_saved_attns):
175
+ # obj_ref_ca_map_items has all timesteps.
176
+ # Format: (timestep (index), attn_key, batch, heads, 2d dim, num text tokens (selected 1))
177
+
178
+ # Different from ca_loss without ref, which has one loss for all boxes for a phrase (a set of object positions), we have one loss per box.
179
+
180
+ # obj_ref_ca_saved_attn_items: select the timestep
181
+ obj_ref_ca_saved_attn = obj_ref_ca_saved_attn[index]
182
+
183
+ for attn_key in guidance_attn_keys:
184
+ attn_map = saved_attn[attn_key]
185
+ if not attn_map.is_cuda:
186
+ attn_map = attn_map.cuda()
187
+ attn_map = attn_map.squeeze(dim=0)
188
+
189
+ obj_ref_ca_map = obj_ref_ca_saved_attn[attn_key]
190
+ if not obj_ref_ca_map.is_cuda:
191
+ obj_ref_ca_map = obj_ref_ca_map.cuda()
192
+ # obj_ref_ca_map: (batch, heads, 2d dim, num text token)
193
+ # `squeeze` on `obj_ref_ca_map` is combined with the subsequent indexing
194
+
195
+ # b is the number of heads, not batch
196
+ b, i, j = attn_map.shape
197
+ H = W = int(math.sqrt(i))
198
+ # `obj_ref_ca_map` only has one text token (the 0 at the last dimension)
199
+
200
+ assert obj_ref_ca_map.ndim == 4, f"{obj_ref_ca_map.shape}"
201
+ obj_ref_ca_map = obj_ref_ca_map[0, :, :, 0]
202
+
203
+ # Same mask for all heads
204
+ obj_mask = torch.zeros(size=(H, W), device="cuda")
205
+ # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
206
+ x_min, y_min, x_max, y_max = utils.scale_proportion(obj_box, H=H, W=W)
207
+ obj_mask[y_min: y_max, x_min: x_max] = 1
208
+
209
+ # keep 1d mask
210
+ obj_mask = obj_mask.reshape(1, -1)
211
+
212
+ # Optimize the loss over the last phrase token only (assuming the indices in `object_positions[obj_idx]` is sorted)
213
+ if ref_ca_word_token_only:
214
+ object_positions_to_iterate = [word_token_indices[obj_idx]]
215
+ elif ref_ca_last_token_only:
216
+ object_positions_to_iterate = [object_positions[obj_idx][-1]]
217
+ else:
218
+ print(f"Applying attention transfer from one attention to all attention maps in object positions {object_positions[obj_idx]}, which is likely to be incorrect")
219
+ object_positions_to_iterate = object_positions[obj_idx]
220
+ for obj_position in object_positions_to_iterate:
221
+ ca_map_obj = attn_map[:, :, obj_position]
222
+
223
+ ca_map_obj_masked = ca_map_obj * obj_mask
224
+
225
+ # Add eps because the sum can be very small, causing NaN
226
+ ca_map_obj_masked_normalized = ca_map_obj_masked / (ca_map_obj_masked.sum(dim=-1, keepdim=True) + eps)
227
+ obj_ref_ca_map_masked = obj_ref_ca_map * obj_mask
228
+ obj_ref_ca_map_masked_normalized = obj_ref_ca_map_masked / (obj_ref_ca_map_masked.sum(dim=-1, keepdim=True) + eps)
229
+
230
+ # We found dividing by object mask size makes the loss too small. Since the normalized masked attn has mean value inversely proportional to the mask size, summing the values up spatially gives a relatively standard scale to add to other losses.
231
+ activation_value = (torch.abs(ca_map_obj_masked_normalized - obj_ref_ca_map_masked_normalized)).sum(dim=-1)
232
+
233
+ obj_loss += torch.mean(activation_value, dim=0)
234
+
235
+ # The normalization for len(obj_ref_ca_map_items) is at the outside of this function.
236
+ # Note that we assume we have at least one box for each object
237
+ loss += loss_weight * obj_loss / (len(obj_boxes) * len(object_positions_to_iterate))
238
+
239
+ if verbose:
240
+ print(f"reference cross-attention obj_loss: unweighted {obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}, weighted {loss_weight * obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}")
241
+
242
+ return loss
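The transfer term above can be summarized as follows, with A-hat and A-hat_ref the box-masked, per-head-normalized cross-attention of the overall and the single-object (reference) generation, lambda the loss weight, and epsilon the stabilizer; the result is further divided by the number of boxes and token positions per object:
\hat{A} = \frac{A \odot m}{\sum (A \odot m) + \epsilon}, \qquad
\mathcal{L}_{\mathrm{ref}} \mathrel{+}= \lambda \cdot \operatorname{mean}_h \sum_{\mathrm{spatial}} \big| \hat{A} - \hat{A}_{\mathrm{ref}} \big|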
243
+
244
+ def compute_ca_lossv3(saved_attn, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns=None, ref_ca_last_token_only=True, ref_ca_word_token_only=False, word_token_indices=None, index=None, ref_ca_loss_weight=1.0, verbose=False, **kwargs):
245
+ """
246
+ v3 is equivalent to v2 but with new dictionary format for attention maps.
247
+ The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing this loss.
248
+ `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
249
+
250
+ `index` is the timestep.
251
+ `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
252
+ `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
253
+ """
254
+ loss = torch.tensor(0).float().cuda()
255
+ object_number = len(bboxes)
256
+ if object_number == 0:
257
+ return loss
258
+
259
+ for attn_key in guidance_attn_keys:
260
+ # We only have 1 cross attention for mid.
261
+ attn_map_integrated = saved_attn[attn_key]
262
+ if not attn_map_integrated.is_cuda:
263
+ attn_map_integrated = attn_map_integrated.cuda()
264
+ # Example dimension: [20, 64, 77]
265
+ attn_map = attn_map_integrated.squeeze(dim=0)
266
+ loss = add_ca_loss_per_attn_map_to_loss(loss, attn_map, object_number, bboxes, object_positions, verbose=verbose, **kwargs)
267
+
268
+ num_attn = len(guidance_attn_keys)
269
+
270
+ if num_attn > 0:
271
+ loss = loss / (object_number * num_attn)
272
+
273
+ if ref_ca_saved_attns is not None:
274
+ ref_loss = torch.tensor(0).float().cuda()
275
+ ref_loss = add_ref_ca_loss_per_attn_map_to_lossv2(
276
+ ref_loss, saved_attn=saved_attn, object_number=object_number, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys,
277
+ ref_ca_saved_attns=ref_ca_saved_attns, ref_ca_last_token_only=ref_ca_last_token_only, ref_ca_word_token_only=ref_ca_word_token_only, word_token_indices=word_token_indices, verbose=verbose, index=index, loss_weight=ref_ca_loss_weight
278
+ )
279
+
280
+ num_attn = len(guidance_attn_keys)
281
+
282
+ if verbose:
283
+ print(f"loss {loss.item():.3f}, reference attention loss (weighted) {ref_loss.item() / (object_number * num_attn):.3f}")
284
+
285
+ loss += ref_loss / (object_number * num_attn)
286
+
287
+ return loss
288
+
289
+ # For compatibility
290
+ def add_ref_ca_loss_per_attn_map_to_loss(loss, attn_maps, object_number, bboxes, object_positions, ref_ca_maps, stage_id, index, verbose=False):
291
+ """
292
+ This adds the ca loss with a reference. Note that this should be used together with the ca loss without a reference, since it only penalizes the difference between the normalized ca of the reference and the target.
293
+
294
+ ref_ca_maps should have the same structure as bboxes and object_positions.
295
+ """
296
+ # attn_map_items is all cond ca maps for current down/mid/up for the overall generation.
297
+ attn_map_items = attn_maps[stage_id]
298
+
299
+ for obj_idx in range(object_number):
300
+ obj_loss = 0
301
+
302
+ obj_boxes = bboxes[obj_idx]
303
+ obj_ref_ca_maps = ref_ca_maps[obj_idx]
304
+
305
+ # We support two levels (one box per phrase) and three levels (multiple boxes per phrase)
306
+ if not isinstance(obj_boxes[0], Iterable):
307
+ obj_boxes = [obj_boxes]
308
+ obj_ref_ca_maps = [obj_ref_ca_maps]
309
+
310
+ assert len(obj_boxes) == len(obj_ref_ca_maps), f"obj_boxes: {len(obj_boxes)}, obj_ref_ca_maps: {len(obj_ref_ca_maps)}"
311
+
312
+ for obj_box, obj_ref_ca_map_items in zip(obj_boxes, obj_ref_ca_maps):
313
+ # obj_ref_ca_map_items format: (stage, timestep (index), block, batch, heads, 2d dim, num text tokens (selected 1))
314
+ # Different from ca_loss without ref, which has one loss for all boxes for a phrase (a set of object positions), we have one loss per box.
315
+
316
+ # print(len(obj_ref_ca_map_items), obj_ref_ca_map_items[stage_id].shape)
317
+ # Mid example: 1 torch.Size([50, 1, 1, 8, 64, 1])
318
+ # Up example: 3 torch.Size([50, 3, 1, 8, 256, 1])
319
+
320
+ # obj_ref_ca_map_items is all cond ca maps for current down/mid/up for the single object generation.
321
+ obj_ref_ca_map_items = obj_ref_ca_map_items[stage_id][index]
322
+
323
+ for attn_map, obj_ref_ca_map in zip(attn_map_items, obj_ref_ca_map_items):
324
+ attn_map = attn_map.squeeze(dim=0)
325
+ # b is the number of heads, not batch
326
+ b, i, j = attn_map.shape
327
+ H = W = int(math.sqrt(i))
328
+ # obj_ref_ca_map only has one text token (the 0 at the last dimension)
329
+
330
+ assert obj_ref_ca_map.ndim == 4, f"{obj_ref_ca_map.ndim}"
331
+ obj_ref_ca_map = obj_ref_ca_map[0, :, :, 0]
332
+
333
+ # Same mask for all heads
334
+ obj_mask = torch.zeros(size=(H, W), device="cuda")
335
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
336
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
337
+ obj_mask[y_min: y_max, x_min: x_max] = 1
338
+
339
+ # keep 1d mask
340
+ obj_mask = obj_mask.reshape(1, -1)
341
+
342
+ for obj_position in object_positions[obj_idx]:
343
+ ca_map_obj = attn_map[:, :, obj_position]
344
+
345
+ ca_map_obj_masked = ca_map_obj * obj_mask
346
+ obj_ref_ca_map_masked = obj_ref_ca_map * obj_mask
347
+ # We found dividing by object mask size makes the loss too small. Since the normalized masked attn has mean value inversely proportional to the mask size, summing the values up spatially gives a relatively standard scale to add to other losses.
348
+ activation_value = (torch.abs(ca_map_obj_masked / ca_map_obj_masked.sum(dim=-1, keepdim=True) - obj_ref_ca_map_masked / obj_ref_ca_map_masked.sum(dim=-1, keepdim=True))).sum(dim=-1) # / obj_mask.sum()
349
+
350
+ obj_loss += torch.mean(activation_value, dim=0)
351
+
352
+ # The normalization for len(obj_ref_ca_map_items) is at the outside of this function.
353
+ loss += obj_loss / (len(obj_boxes) * len(object_positions[obj_idx]))
354
+
355
+ if verbose:
356
+ print(f"reference cross-attention obj_loss: {obj_loss.item() / (len(obj_boxes) * len(object_positions[obj_idx])):.3f}")
357
+
358
+ return loss
utils/latents.py CHANGED
@@ -44,9 +44,10 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
  
      # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
      if use_fast_schedule:
-         # If we use fast schedule, we only need to compose the frozen steps.
+         # If we use fast schedule, we only compose the frozen steps because the later steps do not match.
          composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape), dtype=dtype)
      else:
+         # Otherwise we compose all steps so that we don't need to compose again if we change the frozen steps.
          composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
      composed_latents[0] = latents_bg
  
@@ -73,7 +74,7 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
          latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
          foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
          mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
-         composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
+         composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all[:fast_after_steps + 1] * mask_tensor_expanded
  
      composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
      return composed_latents, foreground_indices
utils/parse.py CHANGED
@@ -1,33 +1,39 @@
  import ast
- import os
- import json
  from matplotlib.patches import Polygon
  from matplotlib.collections import PatchCollection
  import matplotlib.pyplot as plt
  import numpy as np
- import cv2
+ import warnings
  import inflect
  
  p = inflect.engine()
  
  img_dir = "imgs"
+ objects_text = "Objects: "
  bg_prompt_text = "Background prompt: "
+ bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
+ neg_prompt_text = "Negative prompt: "
+ neg_prompt_text_no_trailing_space = neg_prompt_text.rstrip()
+ 
  # h, w
  box_scale = (512, 512)
  size = box_scale
  size_h, size_w = size
  print(f"Using box scale: {box_scale}")
  
+ 
  def parse_input(text=None, no_input=False):
+     warnings.warn("Parsing input without negative prompt is deprecated.")
+ 
      if not text:
          if no_input:
              return
  
          text = input("Enter the response: ")
-     if "Objects: " in text:
-         text = text.split("Objects: ")[1]
+     if objects_text in text:
+         text = text.split(objects_text)[1]
  
-     text_split = text.split(bg_prompt_text)
+     text_split = text.split(bg_prompt_text_no_trailing_space)
      if len(text_split) == 2:
          gen_boxes, bg_prompt = text_split
      elif len(text_split) == 1:
@@ -38,8 +44,8 @@ def parse_input(text=None, no_input=False):
          while not bg_prompt:
              # Ignore the empty lines in the response
              bg_prompt = input("Enter the background prompt: ").strip()
-             if bg_prompt_text in bg_prompt:
-                 bg_prompt = bg_prompt.split(bg_prompt_text)[1]
+             if bg_prompt_text_no_trailing_space in bg_prompt:
+                 bg_prompt = bg_prompt.split(bg_prompt_text_no_trailing_space)[1]
      else:
          raise ValueError(f"text: {text}")
      try:
@@ -54,7 +60,70 @@ def parse_input(text=None, no_input=False):
  
      return gen_boxes, bg_prompt
  
+ def parse_input_with_negative(text=None, no_input=False):
+     # no_input: should not request interactive input
+ 
+     if not text:
+         if no_input:
+             return
+ 
+         text = input("Enter the response: ")
+     if objects_text in text:
+         text = text.split(objects_text)[1]
+ 
+     text_split = text.split(bg_prompt_text_no_trailing_space)
+     if len(text_split) == 2:
+         gen_boxes, text_rem = text_split
+     elif len(text_split) == 1:
+         if no_input:
+             return
+         gen_boxes = text
+         text_rem = ""
+         while not text_rem:
+             # Ignore the empty lines in the response
+             text_rem = input("Enter the background prompt: ").strip()
+             if bg_prompt_text_no_trailing_space in text_rem:
+                 text_rem = text_rem.split(bg_prompt_text_no_trailing_space)[1]
+     else:
+         raise ValueError(f"text: {text}")
+ 
+     text_split = text_rem.split(neg_prompt_text_no_trailing_space)
+ 
+     if len(text_split) == 2:
+         bg_prompt, neg_prompt = text_split
+     elif len(text_split) == 1:
+         bg_prompt = text_rem
+         # Negative prompt is optional: if it's not provided, we default to an empty string
+         neg_prompt = ""
+         if not no_input:
+             # Ignore the empty lines in the response
+             neg_prompt = input("Enter the negative prompt: ").strip()
+             if neg_prompt_text_no_trailing_space in neg_prompt:
+                 neg_prompt = neg_prompt.split(neg_prompt_text_no_trailing_space)[1]
+     else:
+         raise ValueError(f"text: {text}")
+ 
+     try:
+         gen_boxes = ast.literal_eval(gen_boxes)
+     except SyntaxError as e:
+         # Sometimes the response is in plain text
+         if "No objects" in gen_boxes or gen_boxes.strip() == "":
+             gen_boxes = []
+         else:
+             raise e
+     bg_prompt = bg_prompt.strip()
+     neg_prompt = neg_prompt.strip()
+ 
+     # The LLM may return "None" to mean that no negative prompt is provided.
+     if neg_prompt == "None":
+         neg_prompt = ""
+ 
+     return gen_boxes, bg_prompt, neg_prompt
+ 
  def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
+     if gen_boxes is None:
+         return []
+ 
      if len(gen_boxes) == 0:
          return []
  
@@ -62,9 +131,13 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
      gen_boxes_new = []
      for gen_box in gen_boxes:
          if isinstance(gen_box, dict):
+             if not gen_box['bounding_box']:
+                 continue
              name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
              box_dict_format = True
          else:
+             if not gen_box[1]:
+                 continue
              name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
          if bbox_w <= 0 or bbox_h <= 0:
              # Empty boxes
@@ -73,6 +146,12 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
          if (bbox_w >= size[1] and bbox_h >= size[0]) or bbox_x > size[1] or bbox_y > size[0]:
              # Ignore the background boxes
              continue
+ 
+         if bbox_x < 0 or bbox_y < 0 or bbox_x + bbox_w > size[1] or bbox_y + bbox_h > size[0]:
+             # Out-of-bounds boxes exist: we need to scale and shift all the boxes
+             print(f"**Some boxes are out of bounds: {gen_box}, scaling all the boxes to fit**")
+             scale_boxes = True
+ 
          gen_boxes_new.append(gen_box)
  
      gen_boxes = gen_boxes_new
@@ -99,9 +178,11 @@ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=
  
      # Used if scale_boxes is True
      shift = -bbox_left_x_min
-     scale = size_w / (bbox_right_x_max - bbox_left_x_min)
+     # Make sure the boxes fit both horizontally and vertically
+     scale_w = size_w / (bbox_right_x_max - bbox_left_x_min)
+     scale_h = size_h / (bbox_bottom_y_max - bbox_top_y_min)
  
-     scale = min(scale, max_scale)
+     scale = min(scale_w, scale_h, max_scale)
  
      for gen_box in gen_boxes:
          if box_dict_format:
@@ -165,7 +246,7 @@ def draw_boxes(anns):
      ax.add_collection(p)
  
  
- def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
+ def show_boxes(gen_boxes, bg_prompt=None, neg_prompt=None, ind=None, show=False):
      if len(gen_boxes) == 0:
          return
  
@@ -183,7 +264,7 @@ def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
  
      if bg_prompt is not None:
          ax = plt.gca()
-         ax.text(0, 0, bg_prompt, style='italic',
+         ax.text(0, 0, bg_prompt + f"(Neg: {neg_prompt})" if neg_prompt else bg_prompt, style='italic',
                  bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
  
      c = (np.zeros((1, 3)))
@@ -200,12 +281,6 @@ def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
      draw_boxes(anns)
      if show:
          plt.show()
-     else:
-         print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
-         if ind is not None:
-             plt.savefig(f"{img_dir}/boxes_{ind}.png")
-         plt.savefig(f"{img_dir}/boxes.png")
- 
  
  def show_masks(masks):
      masks_to_show = np.zeros((*size, 3), dtype=np.float32)
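
A short usage sketch of the new parsing path, assuming it is run from the repository root so that `utils.parse` is importable; the response text and box values below are made up for illustration:

from utils.parse import parse_input_with_negative, filter_boxes

response = (
    "Objects: [('a red apple', [50, 300, 120, 110]), ('a wooden bowl', [200, 280, 220, 150])]\n"
    "Background prompt: A photo of a kitchen table\n"
    "Negative prompt: None"
)

gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)
# neg_prompt becomes "" here because the LLM answered "None".
print(gen_boxes, bg_prompt, repr(neg_prompt))

# filter_boxes drops empty/background boxes and, with scale_boxes=True,
# rescales the layout with scale = min(scale_w, scale_h, max_scale) so it fits the canvas.
print(filter_boxes(gen_boxes, scale_boxes=True))
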
utils/utils.py CHANGED
@@ -1,7 +1,6 @@
  import torch
  from PIL import ImageDraw
  import numpy as np
- import os
  import gc
  
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
utils/vis.py ADDED
@@ -0,0 +1,153 @@
+ import matplotlib.pyplot as plt
+ import math
+ import utils
+ from . import parse
+ 
+ save_ind = 0
+ 
+ def visualize(image, title, colorbar=False, show_plot=True, **kwargs):
+     plt.title(title)
+     plt.imshow(image, **kwargs)
+     if colorbar:
+         plt.colorbar()
+     if show_plot:
+         plt.show()
+ 
+ def visualize_arrays(image_title_pairs, colorbar_index=-1, show_plot=True, figsize=None, **kwargs):
+     if figsize is not None:
+         plt.figure(figsize=figsize)
+     num_subplots = len(image_title_pairs)
+     for idx, image_title_pair in enumerate(image_title_pairs):
+         plt.subplot(1, num_subplots, idx + 1)
+         if isinstance(image_title_pair, (list, tuple)):
+             image, title = image_title_pair
+         else:
+             image, title = image_title_pair, None
+ 
+         if title is not None:
+             plt.title(title)
+ 
+         plt.imshow(image, **kwargs)
+         if idx == colorbar_index:
+             plt.colorbar()
+ 
+     if show_plot:
+         plt.show()
+ 
+ def visualize_masked_latents(latents_all, masked_latents, timestep_T=False, timestep_0=True):
+     if timestep_T:
+         # from T to 0
+         latent_idx = 0
+ 
+         plt.subplot(1, 2, 1)
+         plt.title("latents_all (t=T)")
+         plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
+ 
+         plt.subplot(1, 2, 2)
+         plt.title("mask latents (t=T)")
+         plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
+ 
+         plt.show()
+ 
+     if timestep_0:
+         latent_idx = -1
+         plt.subplot(1, 2, 1)
+         plt.title("latents_all (t=0)")
+         plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
+ 
+         plt.subplot(1, 2, 2)
+         plt.title("mask latents (t=0)")
+         plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")
+ 
+         plt.show()
+ 
+ # This function has not been adapted to the new `saved_attn`.
+ def visualize_attn(token_map, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
+     """
+     Visualize cross attention: `stage_id`th downsampling block, mean over all timesteps starting from `visualize_step_start`, `block_id`th Transformer block, second item (conditioned), mean over heads, shown for each token.
+     cross_attention_probs_tensors:
+         One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
+     stage_id: index of the downsampling/mid/upsampling block
+     block_id: index of the transformer block
+     """
+ 
+     plt.figure(figsize=(20, 8))
+ 
+     for token_id in range(len(token_map)):
+         token = token_map[token_id]
+         plt.subplot(1, len(token_map), token_id + 1)
+         plt.title(token)
+         attn = cross_attention_probs_tensors[stage_id][visualize_step_start:].mean(dim=0)[block_id]
+ 
+         if not input_ca_has_condition_only:
+             assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
+             attn = attn[1]
+         else:
+             assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
+             attn = attn[0]
+ 
+         attn = attn.mean(dim=0)[:, token_id]
+         H = W = int(math.sqrt(attn.shape[0]))
+         attn = attn.reshape((H, W))
+         plt.imshow(attn.cpu().numpy())
+ 
+     plt.show()
+ 
+ # This function has not been adapted to the new `saved_attn`.
+ def visualize_across_timesteps(token_id, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
+     """
+     Visualize cross attention for one token across timesteps: `stage_id`th downsampling block, `block_id`th Transformer block, second item (conditioned), mean over heads.
+     cross_attention_probs_tensors:
+         One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
+     stage_id: index of the downsampling/mid/upsampling block
+     block_id: index of the transformer block
+ 
+     `visualize_step_start` is not used. We visualize all timesteps.
+     """
+     plt.figure(figsize=(50, 8))
+ 
+     attn_stage = cross_attention_probs_tensors[stage_id]
+     num_inference_steps = attn_stage.shape[0]
+ 
+     for t in range(num_inference_steps):
+         plt.subplot(1, num_inference_steps, t + 1)
+         plt.title(f"t: {t}")
+ 
+         attn = attn_stage[t][block_id]
+ 
+         if not input_ca_has_condition_only:
+             assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
+             attn = attn[1]
+         else:
+             assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
+             attn = attn[0]
+ 
+         attn = attn.mean(dim=0)[:, token_id]
+         H = W = int(math.sqrt(attn.shape[0]))
+         attn = attn.reshape((H, W))
+         plt.imshow(attn.cpu().numpy())
+         plt.axis("off")
+         plt.tight_layout()
+ 
+     plt.show()
+ 
+ def visualize_bboxes(bboxes, H, W):
+     num_boxes = len(bboxes)
+     for ind, bbox in enumerate(bboxes):
+         plt.subplot(1, num_boxes, ind + 1)
+         fg_mask = utils.proportion_to_mask(bbox, H, W)
+         plt.title(f"transformed bbox ({ind})")
+         plt.imshow(fg_mask.cpu().numpy())
+     plt.show()
+ 
+ def display(image, save_prefix="", ind=None):
+     global save_ind
+     if save_prefix != "":
+         save_prefix = save_prefix + "_"
+     ind = f"{ind}_" if ind is not None else ""
+     path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png"
+ 
+     print(f"Saved to {path}")
+ 
+     image.save(path)
+     save_ind = save_ind + 1
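
A small usage sketch for the new helpers in utils/vis.py with stand-in data; it assumes the repository environment so that `utils` is importable, and that the `imgs` directory exists for saving:

import numpy as np
from PIL import Image
from utils import vis

# Stand-in inputs: two single-channel maps and one RGB image.
attn_a = np.random.rand(64, 64)
attn_b = np.random.rand(64, 64)
image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))

# Show the two maps side by side, with a colorbar on the second panel.
vis.visualize_arrays([(attn_a, "map A"), (attn_b, "map B")], colorbar_index=1)

# Saves the image under parse.img_dir (i.e. "imgs/") and bumps the global save index.
vis.display(image, save_prefix="example")
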