songweig commited on
Commit
41fdef7
·
1 Parent(s): 757e20b

update token map

Browse files
app.py CHANGED
@@ -22,18 +22,17 @@ from share_btn import community_icon_html, loading_icon_html, share_js, css
22
  help_text = """
23
  If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
  1. If you format only a portion of a word rather than the complete word, an error may occur.
25
- 2. The token map may not always accurately capture the region of the formatted tokens. If you're experiencing this problem, experiment with selecting more or fewer tokens to expand or reduce the area covered by the token maps.
26
- 3. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
27
- 4. Consider using a different seed.
28
  """
29
 
30
 
31
  canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
32
  get_js_data = """
33
- async (text_input, negative_prompt, height, width, seed, steps, guidance_weight, color_guidance_weight, rich_text_input) => {
34
  const richEl = document.getElementById("rich-text-root");
35
  const data = richEl? richEl.contentDocument.body._data : {};
36
- return [text_input, negative_prompt, height, width, seed, steps, guidance_weight, color_guidance_weight, JSON.stringify(data)];
37
  }
38
  """
39
  set_js_data = """
@@ -71,9 +70,13 @@ def main():
71
  width: int,
72
  seed: int,
73
  steps: int,
 
 
 
74
  guidance_weight: float,
75
  color_guidance_weight: float,
76
- rich_text_input: str
 
77
  ):
78
  run_dir = 'results/'
79
  # Load region diffusion model.
@@ -88,7 +91,7 @@ def main():
88
  # parse json to span attributes
89
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
90
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
91
- json.loads(text_input), device)
92
 
93
  # create control input for region diffusion
94
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
@@ -108,7 +111,7 @@ def main():
108
  # get token maps from plain text to image generation.
109
  begin_time = time.time()
110
  if model.attention_maps is None:
111
- model.register_evaluation_hooks()
112
  else:
113
  model.reset_attention_maps()
114
  plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
@@ -116,27 +119,38 @@ def main():
116
  guidance_scale=guidance_weight)
117
  print('time lapses to get attention maps: %.4f' %
118
  (time.time()-begin_time))
119
- color_obj_masks, _ = get_token_maps(
120
- model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
121
- model.masks, token_maps = get_token_maps(
122
- model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed, base_tokens)
 
 
 
 
 
 
123
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
124
  interpolation=transforms.InterpolationMode.BICUBIC,
125
  antialias=True)
126
  for color_obj_mask in color_obj_masks]
127
  text_format_dict['color_obj_atten'] = color_obj_masks
128
- model.remove_evaluation_hooks()
129
 
130
  # generate image from rich text
131
  begin_time = time.time()
132
  seed_everything(seed)
 
 
 
 
133
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
134
  height=height, width=width, num_inference_steps=steps,
135
- guidance_scale=guidance_weight, use_grad_guidance=use_grad_guidance,
136
- text_format_dict=text_format_dict)
 
137
  print('time lapses to generate image from rich text: %.4f' %
138
  (time.time()-begin_time))
139
- return [plain_img[0], rich_img[0], token_maps]
140
 
141
  with gr.Blocks(css=css) as demo:
142
  url_params = gr.JSON({}, visible=False, label="URL Params")
@@ -162,6 +176,29 @@ def main():
162
  placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
163
  elem_id="negative_prompt"
164
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  seed = gr.Slider(label='Seed',
166
  minimum=0,
167
  maximum=100000,
@@ -169,15 +206,14 @@ def main():
169
  value=6,
170
  elem_id="seed"
171
  )
172
- color_guidance_weight = gr.Slider(label='Color weight lambda',
173
- minimum=0,
174
- maximum=2,
175
- step=0.1,
176
- value=0.5)
177
  with gr.Accordion('Other Parameters', open=False):
178
  steps = gr.Slider(label='Number of Steps',
179
  minimum=0,
180
- maximum=100,
181
  step=1,
182
  value=41)
183
  guidance_weight = gr.Slider(label='CFG weight',
@@ -206,6 +242,8 @@ def main():
206
  with gr.Row():
207
  plaintext_result = gr.Image(
208
  label='Plain-text', elem_id="plain-text-image")
 
 
209
  token_map = gr.Image(label='Token Maps')
210
  with gr.Row(visible=False) as share_row:
211
  with gr.Group(elem_id="share-btn-container"):
@@ -218,181 +256,238 @@ def main():
218
  gr.Markdown(help_text)
219
 
220
  with gr.Row():
221
- style_examples = [
222
  [
223
- '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
224
  '',
225
- 512,
226
- 512,
 
227
  6,
228
  1,
229
- None
 
230
  ],
231
  [
232
- '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
233
  '',
234
- 512,
235
- 512,
236
- 3,
 
237
  1,
238
- None
 
239
  ],
240
  [
241
- '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
242
- 'worst quality, dark, poor quality',
243
- 512,
244
- 512,
245
- 9,
 
246
  1,
247
- None
 
248
  ],
249
  ]
250
- gr.Examples(examples=style_examples,
251
- label='Font style examples',
 
252
  inputs=[
253
  text_input,
254
  negative_prompt,
255
- height,
256
- width,
 
257
  seed,
258
  color_guidance_weight,
259
  rich_text_input,
 
260
  ],
261
  outputs=[
262
  plaintext_result,
263
  richtext_result,
 
264
  token_map,
265
  ],
266
  fn=generate,
267
  # cache_examples=True,
268
  examples_per_page=20)
269
  with gr.Row():
270
- footnote_examples = [
271
  [
272
- '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
273
- '',
274
- 512,
275
- 512,
 
276
  6,
277
- 1,
278
- None
 
279
  ],
280
  [
281
- '{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a built-in oven and a stove with gas burners "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
282
- '',
283
- 512,
284
- 512,
 
285
  6,
286
- 1,
287
- None
 
288
  ],
289
  [
290
- '{"ops":[{"insert":"A "},{"attributes":{"link":"Art inspired by kung fu panda, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
291
  '',
292
- 512,
293
- 512,
 
294
  6,
 
 
 
 
 
 
 
 
 
 
 
295
  1,
296
- None
 
 
 
 
 
 
 
 
 
 
 
 
297
  ],
298
  ]
299
-
300
- gr.Examples(examples=footnote_examples,
301
- label='Footnote examples',
302
  inputs=[
303
  text_input,
304
  negative_prompt,
305
- height,
306
- width,
 
307
  seed,
308
  color_guidance_weight,
309
  rich_text_input,
 
310
  ],
311
  outputs=[
312
  plaintext_result,
313
  richtext_result,
 
314
  token_map,
315
  ],
316
  fn=generate,
317
  # cache_examples=True,
318
  examples_per_page=20)
 
319
  with gr.Row():
320
- color_examples = [
321
  [
322
- '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
323
  '',
324
- 512,
325
- 512,
326
- 6,
327
- 1,
328
- None
 
 
329
  ],
330
  [
331
- '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
332
- '',
333
- 512,
334
- 512,
 
335
  9,
336
- 1,
337
- None
 
338
  ],
339
  [
340
- '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
341
  '',
342
- 512,
343
- 512,
344
- 5,
345
- 0.6,
346
- None
 
 
347
  ],
348
  ]
349
- gr.Examples(examples=color_examples,
350
- label='Font color examples',
351
  inputs=[
352
  text_input,
353
  negative_prompt,
354
- height,
355
- width,
 
356
  seed,
357
  color_guidance_weight,
358
  rich_text_input,
 
359
  ],
360
  outputs=[
361
  plaintext_result,
362
  richtext_result,
 
363
  token_map,
364
  ],
365
  fn=generate,
366
  # cache_examples=True,
367
  examples_per_page=20)
 
368
  with gr.Row():
369
  size_examples = [
370
  [
371
  '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
372
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
373
- 512,
374
- 512,
 
375
  13,
376
  1,
377
- None
 
378
  ],
379
  [
380
  '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
381
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
382
- 512,
383
- 512,
 
384
  13,
385
  1,
386
- None
 
387
  ],
388
  [
389
  '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
390
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
391
- 512,
392
- 512,
 
393
  13,
394
  1,
395
- None
 
396
  ],
397
  ]
398
  gr.Examples(examples=size_examples,
@@ -400,15 +495,18 @@ def main():
400
  inputs=[
401
  text_input,
402
  negative_prompt,
403
- height,
404
- width,
 
405
  seed,
406
  color_guidance_weight,
407
  rich_text_input,
 
408
  ],
409
  outputs=[
410
  plaintext_result,
411
  richtext_result,
 
412
  token_map,
413
  ],
414
  fn=generate,
@@ -423,11 +521,15 @@ def main():
423
  width,
424
  seed,
425
  steps,
 
 
 
426
  guidance_weight,
427
  color_guidance_weight,
428
- rich_text_input
 
429
  ],
430
- outputs=[plaintext_result, richtext_result, token_map],
431
  _js=get_js_data
432
  ).then(
433
  fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
 
22
  help_text = """
23
  If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
  1. If you format only a portion of a word rather than the complete word, an error may occur.
25
+ 2. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
26
+ 3. Consider using a different seed.
 
27
  """
28
 
29
 
30
  canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
  get_js_data = """
32
+ async (text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, rich_text_input, background_aug) => {
33
  const richEl = document.getElementById("rich-text-root");
34
  const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, JSON.stringify(data), background_aug];
36
  }
37
  """
38
  set_js_data = """
 
70
  width: int,
71
  seed: int,
72
  steps: int,
73
+ num_segments: int,
74
+ segment_threshold: float,
75
+ inject_interval: float,
76
  guidance_weight: float,
77
  color_guidance_weight: float,
78
+ rich_text_input: str,
79
+ background_aug: bool,
80
  ):
81
  run_dir = 'results/'
82
  # Load region diffusion model.
 
91
  # parse json to span attributes
92
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
93
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
94
+ json.loads(text_input))
95
 
96
  # create control input for region diffusion
97
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
 
111
  # get token maps from plain text to image generation.
112
  begin_time = time.time()
113
  if model.attention_maps is None:
114
+ model.register_tokenmap_hooks()
115
  else:
116
  model.reset_attention_maps()
117
  plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
 
119
  guidance_scale=guidance_weight)
120
  print('time lapses to get attention maps: %.4f' %
121
  (time.time()-begin_time))
122
+ seed_everything(seed)
123
+ color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
124
+ 512//8, 512//8, color_target_token_ids[:-1], seed,
125
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
126
+ return_vis=True)
127
+ seed_everything(seed)
128
+ model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
129
+ 512//8, 512//8, region_target_token_ids[:-1], seed,
130
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
131
+ return_vis=True)
132
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
133
  interpolation=transforms.InterpolationMode.BICUBIC,
134
  antialias=True)
135
  for color_obj_mask in color_obj_masks]
136
  text_format_dict['color_obj_atten'] = color_obj_masks
137
+ model.remove_tokenmap_hooks()
138
 
139
  # generate image from rich text
140
  begin_time = time.time()
141
  seed_everything(seed)
142
+ if background_aug:
143
+ bg_aug_end = 500
144
+ else:
145
+ bg_aug_end = 1000
146
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
147
  height=height, width=width, num_inference_steps=steps,
148
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
149
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
150
+ bg_aug_end=bg_aug_end)
151
  print('time lapses to generate image from rich text: %.4f' %
152
  (time.time()-begin_time))
153
+ return [plain_img[0], rich_img[0], segments_vis, token_maps]
154
 
155
  with gr.Blocks(css=css) as demo:
156
  url_params = gr.JSON({}, visible=False, label="URL Params")
 
176
  placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
177
  elem_id="negative_prompt"
178
  )
179
+ segment_threshold = gr.Slider(label='Token map threshold',
180
+ info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
181
+ minimum=0,
182
+ maximum=1,
183
+ step=0.01,
184
+ value=0.25)
185
+ inject_interval = gr.Slider(label='Detail preservation',
186
+ info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
187
+ minimum=0,
188
+ maximum=1,
189
+ step=0.01,
190
+ value=0.)
191
+ color_guidance_weight = gr.Slider(label='Color weight',
192
+ info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
193
+ minimum=0,
194
+ maximum=2,
195
+ step=0.1,
196
+ value=0.5)
197
+ num_segments = gr.Slider(label='Number of segments',
198
+ minimum=2,
199
+ maximum=20,
200
+ step=1,
201
+ value=9)
202
  seed = gr.Slider(label='Seed',
203
  minimum=0,
204
  maximum=100000,
 
206
  value=6,
207
  elem_id="seed"
208
  )
209
+ background_aug = gr.Checkbox(
210
+ label='Precise region alignment',
211
+ info='(For strict region alignment, select this option, but beware of potential artifacts when using with style.)',
212
+ value=True)
 
213
  with gr.Accordion('Other Parameters', open=False):
214
  steps = gr.Slider(label='Number of Steps',
215
  minimum=0,
216
+ maximum=500,
217
  step=1,
218
  value=41)
219
  guidance_weight = gr.Slider(label='CFG weight',
 
242
  with gr.Row():
243
  plaintext_result = gr.Image(
244
  label='Plain-text', elem_id="plain-text-image")
245
+ segments = gr.Image(label='Segmentation')
246
+ with gr.Row():
247
  token_map = gr.Image(label='Token Maps')
248
  with gr.Row(visible=False) as share_row:
249
  with gr.Group(elem_id="share-btn-container"):
 
256
  gr.Markdown(help_text)
257
 
258
  with gr.Row():
259
+ footnote_examples = [
260
  [
261
+ '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
262
  '',
263
+ 5,
264
+ 0.3,
265
+ 0,
266
  6,
267
  1,
268
+ None,
269
+ True
270
  ],
271
  [
272
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a stove with gas burners and a built-in oven "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
273
  '',
274
+ 6,
275
+ 0.5,
276
+ 0,
277
+ 6,
278
  1,
279
+ None,
280
+ True
281
  ],
282
  [
283
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
284
+ '',
285
+ 4,
286
+ 0.3,
287
+ 0,
288
+ 4,
289
  1,
290
+ None,
291
+ True
292
  ],
293
  ]
294
+
295
+ gr.Examples(examples=footnote_examples,
296
+ label='Footnote examples',
297
  inputs=[
298
  text_input,
299
  negative_prompt,
300
+ num_segments,
301
+ segment_threshold,
302
+ inject_interval,
303
  seed,
304
  color_guidance_weight,
305
  rich_text_input,
306
+ background_aug,
307
  ],
308
  outputs=[
309
  plaintext_result,
310
  richtext_result,
311
+ segments,
312
  token_map,
313
  ],
314
  fn=generate,
315
  # cache_examples=True,
316
  examples_per_page=20)
317
  with gr.Row():
318
+ color_examples = [
319
  [
320
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#00ffff"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
321
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
322
+ 9,
323
+ 0.25,
324
+ 0.3,
325
  6,
326
+ 0.5,
327
+ None,
328
+ True
329
  ],
330
  [
331
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#eeeeee"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
332
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
333
+ 9,
334
+ 0.25,
335
+ 0.3,
336
  6,
337
+ 0.1,
338
+ None,
339
+ True
340
  ],
341
  [
342
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
343
  '',
344
+ 5,
345
+ 0.3,
346
+ 0.3,
347
  6,
348
+ 0.5,
349
+ None,
350
+ False
351
+ ],
352
+ [
353
+ '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
354
+ '',
355
+ 3,
356
+ 0.3,
357
+ 0,
358
+ 9,
359
  1,
360
+ None,
361
+ False
362
+ ],
363
+ [
364
+ '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
365
+ '',
366
+ 5,
367
+ 0.3,
368
+ 0,
369
+ 5,
370
+ 0.6,
371
+ None,
372
+ False
373
  ],
374
  ]
375
+ gr.Examples(examples=color_examples,
376
+ label='Font color examples',
 
377
  inputs=[
378
  text_input,
379
  negative_prompt,
380
+ num_segments,
381
+ segment_threshold,
382
+ inject_interval,
383
  seed,
384
  color_guidance_weight,
385
  rich_text_input,
386
+ background_aug,
387
  ],
388
  outputs=[
389
  plaintext_result,
390
  richtext_result,
391
+ segments,
392
  token_map,
393
  ],
394
  fn=generate,
395
  # cache_examples=True,
396
  examples_per_page=20)
397
+
398
  with gr.Row():
399
+ style_examples = [
400
  [
401
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
402
  '',
403
+ 5,
404
+ 0.3,
405
+ 0.2,
406
+ 3,
407
+ 0.5,
408
+ None,
409
+ False
410
  ],
411
  [
412
+ '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
413
+ 'worst quality, dark, poor quality',
414
+ 5,
415
+ 0.3,
416
+ 0,
417
  9,
418
+ 0.5,
419
+ None,
420
+ False
421
  ],
422
  [
423
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
424
  '',
425
+ 2,
426
+ 0.4,
427
+ 0,
428
+ 6,
429
+ 0.5,
430
+ None,
431
+ False
432
  ],
433
  ]
434
+ gr.Examples(examples=style_examples,
435
+ label='Font style examples',
436
  inputs=[
437
  text_input,
438
  negative_prompt,
439
+ num_segments,
440
+ segment_threshold,
441
+ inject_interval,
442
  seed,
443
  color_guidance_weight,
444
  rich_text_input,
445
+ background_aug,
446
  ],
447
  outputs=[
448
  plaintext_result,
449
  richtext_result,
450
+ segments,
451
  token_map,
452
  ],
453
  fn=generate,
454
  # cache_examples=True,
455
  examples_per_page=20)
456
+
457
  with gr.Row():
458
  size_examples = [
459
  [
460
  '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
461
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
462
+ 5,
463
+ 0.3,
464
+ 0,
465
  13,
466
  1,
467
+ None,
468
+ False
469
  ],
470
  [
471
  '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
472
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
473
+ 5,
474
+ 0.3,
475
+ 0,
476
  13,
477
  1,
478
+ None,
479
+ False
480
  ],
481
  [
482
  '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
483
  'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
484
+ 5,
485
+ 0.3,
486
+ 0,
487
  13,
488
  1,
489
+ None,
490
+ False
491
  ],
492
  ]
493
  gr.Examples(examples=size_examples,
 
495
  inputs=[
496
  text_input,
497
  negative_prompt,
498
+ num_segments,
499
+ segment_threshold,
500
+ inject_interval,
501
  seed,
502
  color_guidance_weight,
503
  rich_text_input,
504
+ background_aug,
505
  ],
506
  outputs=[
507
  plaintext_result,
508
  richtext_result,
509
+ segments,
510
  token_map,
511
  ],
512
  fn=generate,
 
521
  width,
522
  seed,
523
  steps,
524
+ num_segments,
525
+ segment_threshold,
526
+ inject_interval,
527
  guidance_weight,
528
  color_guidance_weight,
529
+ rich_text_input,
530
+ background_aug
531
  ],
532
+ outputs=[plaintext_result, richtext_result, segments, token_map],
533
  _js=get_js_data
534
  ).then(
535
  fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
models/attention.py CHANGED
@@ -492,7 +492,7 @@ class BasicTransformerBlock(nn.Module):
492
 
493
  if self.only_cross_attention:
494
  attn_out, _ = self.attn1(
495
- norm_hidden_states, context, text_format_dict=text_format_dict) + hidden_states
496
  hidden_states = attn_out + hidden_states
497
  else:
498
  attn_out, _ = self.attn1(norm_hidden_states)
@@ -583,7 +583,7 @@ class CrossAttention(nn.Module):
583
  head_size, seq_len, seq_len2)
584
  return tensor.mean(1)
585
 
586
- def forward(self, hidden_states, context=None, mask=None, text_format_dict={}):
587
  batch_size, sequence_length, _ = hidden_states.shape
588
 
589
  query = self.to_q(hidden_states)
@@ -607,7 +607,7 @@ class CrossAttention(nn.Module):
607
  if self._slice_size is None or query.shape[0] // self._slice_size == 1:
608
  # only this attention function is used
609
  hidden_states, attn_probs = self._attention(
610
- query, key, value, **text_format_dict)
611
 
612
  # linear proj
613
  hidden_states = self.to_out[0](hidden_states)
@@ -625,11 +625,11 @@ class CrossAttention(nn.Module):
625
  alpha=self.scale,
626
  )
627
 
628
- def _attention(self, query, key, value, word_pos=None, font_size=None,
629
  **kwargs):
630
  attention_scores = self._qk(query, key)
631
 
632
- # Font size:
633
  if self.is_cross_attn and word_pos is not None and font_size is not None:
634
  assert key.shape[1] == 77
635
  attention_score_exp = attention_scores.exp()
@@ -642,13 +642,25 @@ class CrossAttention(nn.Module):
642
  else:
643
  attention_probs = attention_scores.softmax(dim=-1)
644
 
645
- hidden_states = torch.bmm(attention_probs, value)
 
 
 
 
 
 
 
 
 
 
646
 
647
  # reshape hidden_states
648
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
649
- attention_probs = self.reshape_batch_dim_to_heads_and_average(
 
 
650
  attention_probs)
651
- return hidden_states, attention_probs
652
 
653
  def _memory_efficient_attention_xformers(self, query, key, value):
654
  query = query.contiguous()
 
492
 
493
  if self.only_cross_attention:
494
  attn_out, _ = self.attn1(
495
+ norm_hidden_states, context=context, text_format_dict=text_format_dict) + hidden_states
496
  hidden_states = attn_out + hidden_states
497
  else:
498
  attn_out, _ = self.attn1(norm_hidden_states)
 
583
  head_size, seq_len, seq_len2)
584
  return tensor.mean(1)
585
 
586
+ def forward(self, hidden_states, real_attn_probs=None, context=None, mask=None, text_format_dict={}):
587
  batch_size, sequence_length, _ = hidden_states.shape
588
 
589
  query = self.to_q(hidden_states)
 
607
  if self._slice_size is None or query.shape[0] // self._slice_size == 1:
608
  # only this attention function is used
609
  hidden_states, attn_probs = self._attention(
610
+ query, key, value, real_attn_probs, **text_format_dict)
611
 
612
  # linear proj
613
  hidden_states = self.to_out[0](hidden_states)
 
625
  alpha=self.scale,
626
  )
627
 
628
+ def _attention(self, query, key, value, real_attn_probs=None, word_pos=None, font_size=None,
629
  **kwargs):
630
  attention_scores = self._qk(query, key)
631
 
632
+ # Font size V2:
633
  if self.is_cross_attn and word_pos is not None and font_size is not None:
634
  assert key.shape[1] == 77
635
  attention_score_exp = attention_scores.exp()
 
642
  else:
643
  attention_probs = attention_scores.softmax(dim=-1)
644
 
645
+ # compute attention output
646
+ if real_attn_probs is None:
647
+ hidden_states = torch.bmm(attention_probs, value)
648
+ else:
649
+ if isinstance(real_attn_probs, dict):
650
+ for pos1, pos2 in zip(real_attn_probs['inject_pos'][0], real_attn_probs['inject_pos'][1]):
651
+ attention_probs[:, :,
652
+ pos2] = real_attn_probs['reference'][:, :, pos1]
653
+ hidden_states = torch.bmm(attention_probs, value)
654
+ else:
655
+ hidden_states = torch.bmm(real_attn_probs, value)
656
 
657
  # reshape hidden_states
658
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
659
+
660
+ # we also return the map averaged over heads to save memory footprint
661
+ attention_probs_avg = self.reshape_batch_dim_to_heads_and_average(
662
  attention_probs)
663
+ return hidden_states, [attention_probs_avg, attention_probs]
664
 
665
  def _memory_efficient_attention_xformers(self, query, key, value):
666
  query = query.contiguous()
models/region_diffusion.py CHANGED
@@ -6,6 +6,7 @@ from functools import partial
6
  from transformers import CLIPTextModel, CLIPTokenizer, logging
7
  from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
  from models.unet_2d_condition import UNet2DConditionModel
 
9
 
10
  # suppress partial model loading warning
11
  logging.set_verbosity_error()
@@ -38,6 +39,7 @@ class RegionDiffusion(nn.Module):
38
  self.masks = []
39
  self.attention_maps = None
40
  self.color_loss = torch.nn.functional.mse_loss
 
41
 
42
  print(f'[INFO] loaded stable diffusion!')
43
 
@@ -79,47 +81,83 @@ class RegionDiffusion(nn.Module):
79
  return text_embeddings
80
 
81
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
82
- latents=None, use_grad_guidance=False, text_format_dict={}):
83
 
84
  if latents is None:
85
  latents = torch.randn(
86
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
87
 
 
 
88
  self.scheduler.set_timesteps(num_inference_steps)
89
  n_styles = text_embeddings.shape[0]-1
90
  assert n_styles == len(self.masks)
91
-
92
  with torch.autocast('cuda'):
93
  for i, t in enumerate(self.scheduler.timesteps):
94
 
95
  # predict the noise residual
96
  with torch.no_grad():
97
- noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
98
- text_format_dict={})['sample']
99
- noise_pred_text = None
100
- for style_i, mask in enumerate(self.masks):
101
- if style_i < len(self.masks) - 1:
102
- masked_latent = latents
103
- noise_pred_text_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
 
104
  text_format_dict={})['sample']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  else:
106
- noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
107
- text_format_dict=text_format_dict)['sample']
108
- if noise_pred_text is None:
109
- noise_pred_text = noise_pred_text_cur * mask
110
- else:
111
- noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
 
112
 
113
  # perform classifier-free guidance
114
  noise_pred = noise_pred_uncond + guidance_scale * \
115
  (noise_pred_text - noise_pred_uncond)
116
 
117
- # compute the previous noisy sample x_t -> x_t-1
118
- latents = self.scheduler.step(noise_pred, t, latents)[
119
- 'prev_sample']
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # apply gradient guidance
122
- if use_grad_guidance and t < text_format_dict['guidance_start_step']:
123
  with torch.enable_grad():
124
  if not latents.requires_grad:
125
  latents.requires_grad = True
@@ -137,7 +175,7 @@ class RegionDiffusion(nn.Module):
137
  loss_total += loss
138
  loss_total.backward()
139
  latents = (
140
- latents - latents.grad * text_format_dict['color_guidance_weight']).detach().clone()
141
 
142
  return latents
143
 
@@ -162,6 +200,7 @@ class RegionDiffusion(nn.Module):
162
  (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
163
 
164
  self.scheduler.set_timesteps(num_inference_steps)
 
165
 
166
  with torch.autocast('cuda'):
167
  for i, t in enumerate(self.scheduler.timesteps):
@@ -202,8 +241,18 @@ class RegionDiffusion(nn.Module):
202
 
203
  return imgs
204
 
 
 
 
 
 
 
 
 
 
 
205
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
206
- guidance_scale=7.5, latents=None, text_format_dict={}, use_grad_guidance=False):
207
 
208
  if isinstance(prompts, str):
209
  prompts = [prompts]
@@ -215,18 +264,11 @@ class RegionDiffusion(nn.Module):
215
  text_embeds = self.get_text_embeds(
216
  prompts, negative_prompts) # [2, 77, 768]
217
 
218
- if len(text_format_dict) > 0:
219
- if 'font_styles' in text_format_dict and text_format_dict['font_styles'] is not None:
220
- text_format_dict['font_styles_embs'] = self.get_text_embeds_list(
221
- text_format_dict['font_styles']) # [2, 77, 768]
222
- else:
223
- text_format_dict['font_styles_embs'] = None
224
-
225
  # else:
226
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
227
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
228
- use_grad_guidance=use_grad_guidance, text_format_dict=text_format_dict) # [1, 4, 64, 64]
229
-
230
  # Img latents -> imgs
231
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
232
 
@@ -272,7 +314,156 @@ class RegionDiffusion(nn.Module):
272
  # attention_dict is a dictionary containing attention maps for every attention layer
273
  self.attention_maps = attention_dict
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  def remove_evaluation_hooks(self):
276
  for hook in self.forward_hooks:
277
  hook.remove()
278
  self.attention_maps = None
 
 
 
 
 
 
 
 
 
6
  from transformers import CLIPTextModel, CLIPTokenizer, logging
7
  from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
  from models.unet_2d_condition import UNet2DConditionModel
9
+ from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers
10
 
11
  # suppress partial model loading warning
12
  logging.set_verbosity_error()
 
39
  self.masks = []
40
  self.attention_maps = None
41
  self.color_loss = torch.nn.functional.mse_loss
42
+ self.forward_replacement_hooks = []
43
 
44
  print(f'[INFO] loaded stable diffusion!')
45
 
 
81
  return text_embeddings
82
 
83
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
84
+ latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, bg_aug_end=1000):
85
 
86
  if latents is None:
87
  latents = torch.randn(
88
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
89
 
90
+ if inject_selfattn > 0:
91
+ latents_reference = latents.clone().detach()
92
  self.scheduler.set_timesteps(num_inference_steps)
93
  n_styles = text_embeddings.shape[0]-1
94
  assert n_styles == len(self.masks)
 
95
  with torch.autocast('cuda'):
96
  for i, t in enumerate(self.scheduler.timesteps):
97
 
98
  # predict the noise residual
99
  with torch.no_grad():
100
+ # tokens without any attributes
101
+ feat_inject_step = t > (1-inject_selfattn) * 1000
102
+ noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
103
+ text_format_dict={})['sample']
104
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
105
+ text_format_dict=text_format_dict)['sample']
106
+ if inject_selfattn > 0:
107
+ noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
108
  text_format_dict={})['sample']
109
+ self.register_selfattn_hooks(feat_inject_step)
110
+ noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
111
+ text_format_dict={})['sample']
112
+ self.remove_selfattn_hooks()
113
+ noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
114
+ noise_pred_text = noise_pred_text_cur * self.masks[-1]
115
+ # tokens with attributes
116
+ for style_i, mask in enumerate(self.masks[:-1]):
117
+ if t > bg_aug_end:
118
+ rand_rgb = torch.rand([1, 3, 1, 1]).cuda()
119
+ black_background = torch.ones(
120
+ [1, 3, height, width]).cuda()*rand_rgb
121
+ black_latent = self.encode_imgs(
122
+ black_background)
123
+ noise = torch.randn_like(black_latent)
124
+ black_latent_noisy = self.scheduler.add_noise(
125
+ black_latent, noise, t)
126
+ masked_latent = (
127
+ mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy
128
+ noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[:1],
129
+ text_format_dict={})['sample']
130
  else:
131
+ masked_latent = latents
132
+ self.register_replacement_hooks(feat_inject_step)
133
+ noise_pred_text_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
134
+ text_format_dict={})['sample']
135
+ self.remove_replacement_hooks()
136
+ noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
137
+ noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
138
 
139
  # perform classifier-free guidance
140
  noise_pred = noise_pred_uncond + guidance_scale * \
141
  (noise_pred_text - noise_pred_uncond)
142
 
143
+ if inject_selfattn > 0:
144
+ noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
145
+ (noise_pred_text_refer - noise_pred_uncond_refer)
146
+
147
+ # compute the previous noisy sample x_t -> x_t-1
148
+ latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
149
+ torch.cat([latents, latents_reference]))[
150
+ 'prev_sample']
151
+ latents, latents_reference = torch.chunk(
152
+ latents_reference, 2, dim=0)
153
+
154
+ else:
155
+ # compute the previous noisy sample x_t -> x_t-1
156
+ latents = self.scheduler.step(noise_pred, t, latents)[
157
+ 'prev_sample']
158
 
159
+ # apply guidance
160
+ if use_guidance and t < text_format_dict['guidance_start_step']:
161
  with torch.enable_grad():
162
  if not latents.requires_grad:
163
  latents.requires_grad = True
 
175
  loss_total += loss
176
  loss_total.backward()
177
  latents = (
178
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()
179
 
180
  return latents
181
 
 
200
  (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
201
 
202
  self.scheduler.set_timesteps(num_inference_steps)
203
+ self.remove_replacement_hooks()
204
 
205
  with torch.autocast('cuda'):
206
  for i, t in enumerate(self.scheduler.timesteps):
 
241
 
242
  return imgs
243
 
244
+ def encode_imgs(self, imgs):
245
+ # imgs: [B, 3, H, W]
246
+
247
+ imgs = 2 * imgs - 1
248
+
249
+ posterior = self.vae.encode(imgs).latent_dist
250
+ latents = posterior.sample() * 0.18215
251
+
252
+ return latents
253
+
254
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
255
+ guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, bg_aug_end=1000):
256
 
257
  if isinstance(prompts, str):
258
  prompts = [prompts]
 
264
  text_embeds = self.get_text_embeds(
265
  prompts, negative_prompts) # [2, 77, 768]
266
 
 
 
 
 
 
 
 
267
  # else:
268
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
269
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
270
+ use_guidance=use_guidance, text_format_dict=text_format_dict,
271
+ inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end) # [1, 4, 64, 64]
272
  # Img latents -> imgs
273
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
274
 
 
314
  # attention_dict is a dictionary containing attention maps for every attention layer
315
  self.attention_maps = attention_dict
316
 
317
+ def register_selfattn_hooks(self, feat_inject_step=False):
318
+ r"""Function for registering hooks during evaluation.
319
+ We mainly store activation maps averaged over queries.
320
+ """
321
+ self.selfattn_forward_hooks = []
322
+
323
+ def save_activations(activations, name, module, inp, out):
324
+ r"""
325
+ PyTorch Forward hook to save outputs at each forward pass.
326
+ """
327
+ # out[0] - final output of attention layer
328
+ # out[1] - attention probability matrix
329
+ if 'attn2' in name:
330
+ assert out[1][1].shape[-1] == 77
331
+ # cross attention injection
332
+ # activations[name] = out[1][1].detach()
333
+ else:
334
+ assert out[1][1].shape[-1] != 77
335
+ activations[name] = out[1][1].detach()
336
+
337
+ def save_resnet_activations(activations, name, module, inp, out):
338
+ r"""
339
+ PyTorch Forward hook to save outputs at each forward pass.
340
+ """
341
+ # out[0] - final output of residual layer
342
+ # out[1] - residual hidden feature
343
+ # import ipdb
344
+ # ipdb.set_trace()
345
+ assert out[1].shape[-1] == 16
346
+ activations[name] = out[1].detach()
347
+ attention_dict = collections.defaultdict(list)
348
+ for name, module in self.unet.named_modules():
349
+ leaf_name = name.split('.')[-1]
350
+ if 'attn' in leaf_name and feat_inject_step:
351
+ # Register hook to obtain outputs at every attention layer.
352
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
353
+ partial(save_activations, attention_dict, name)
354
+ ))
355
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
356
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
357
+ partial(save_resnet_activations, attention_dict, name)
358
+ ))
359
+ # attention_dict is a dictionary containing attention maps for every attention layer
360
+ self.self_attention_maps_cur = attention_dict
361
+
362
+ def register_replacement_hooks(self, feat_inject_step=False):
363
+ r"""Function for registering hooks to replace self attention.
364
+ """
365
+ self.forward_replacement_hooks = []
366
+
367
+ def replace_activations(name, module, args):
368
+ r"""
369
+ PyTorch Forward hook to save outputs at each forward pass.
370
+ """
371
+ if 'attn1' in name:
372
+ modified_args = (args[0], self.self_attention_maps_cur[name])
373
+ return modified_args
374
+ # cross attention injection
375
+ # elif 'attn2' in name:
376
+ # modified_map = {
377
+ # 'reference': self.self_attention_maps_cur[name],
378
+ # 'inject_pos': self.inject_pos,
379
+ # }
380
+ # modified_args = (args[0], modified_map)
381
+ # return modified_args
382
+
383
+ def replace_resnet_activations(name, module, args):
384
+ r"""
385
+ PyTorch Forward hook to save outputs at each forward pass.
386
+ """
387
+ modified_args = (args[0], args[1],
388
+ self.self_attention_maps_cur[name])
389
+ return modified_args
390
+ for name, module in self.unet.named_modules():
391
+ leaf_name = name.split('.')[-1]
392
+ if 'attn' in leaf_name and feat_inject_step:
393
+ # Register hook to obtain outputs at every attention layer.
394
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
395
+ partial(replace_activations, name)
396
+ ))
397
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
398
+ # Register hook to obtain outputs at every attention layer.
399
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
400
+ partial(replace_resnet_activations, name)
401
+ ))
402
+
403
+ def register_tokenmap_hooks(self):
404
+ r"""Function for registering hooks during evaluation.
405
+ We mainly store activation maps averaged over queries.
406
+ """
407
+ self.forward_hooks = []
408
+
409
+ def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
410
+ r"""
411
+ PyTorch Forward hook to save outputs at each forward pass.
412
+ """
413
+ # out[0] - final output of attention layer
414
+ # out[1] - attention probability matrices
415
+ if name in n_maps:
416
+ n_maps[name] += 1
417
+ else:
418
+ n_maps[name] = 1
419
+ if 'attn2' in name:
420
+ assert out[1][0].shape[-1] == 77
421
+ if name in CrossAttentionLayers and n_maps[name] > 10:
422
+ if name in crossattn_maps:
423
+ crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
424
+ else:
425
+ crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
426
+ else:
427
+ assert out[1][0].shape[-1] != 77
428
+ if name in SelfAttentionLayers and n_maps[name] > 10:
429
+ if name in crossattn_maps:
430
+ selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
431
+ else:
432
+ selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
433
+
434
+ selfattn_maps = collections.defaultdict(list)
435
+ crossattn_maps = collections.defaultdict(list)
436
+ n_maps = collections.defaultdict(list)
437
+
438
+ for name, module in self.unet.named_modules():
439
+ leaf_name = name.split('.')[-1]
440
+ if 'attn' in leaf_name:
441
+ # Register hook to obtain outputs at every attention layer.
442
+ self.forward_hooks.append(module.register_forward_hook(
443
+ partial(save_activations, selfattn_maps,
444
+ crossattn_maps, n_maps, name)
445
+ ))
446
+ # attention_dict is a dictionary containing attention maps for every attention layer
447
+ self.selfattn_maps = selfattn_maps
448
+ self.crossattn_maps = crossattn_maps
449
+ self.n_maps = n_maps
450
+
451
+ def remove_tokenmap_hooks(self):
452
+ for hook in self.forward_hooks:
453
+ hook.remove()
454
+ self.selfattn_maps = None
455
+ self.crossattn_maps = None
456
+ self.n_maps = None
457
+
458
  def remove_evaluation_hooks(self):
459
  for hook in self.forward_hooks:
460
  hook.remove()
461
  self.attention_maps = None
462
+
463
+ def remove_replacement_hooks(self):
464
+ for hook in self.forward_replacement_hooks:
465
+ hook.remove()
466
+
467
+ def remove_selfattn_hooks(self):
468
+ for hook in self.selfattn_forward_hooks:
469
+ hook.remove()
models/unet_2d_blocks.py CHANGED
@@ -16,7 +16,7 @@ import torch
16
  from torch import nn
17
 
18
  from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
19
- from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
 
21
 
22
  def get_down_block(
@@ -36,7 +36,8 @@ def get_down_block(
36
  use_linear_projection=False,
37
  only_cross_attention=False,
38
  ):
39
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
 
40
  if down_block_type == "DownBlock2D":
41
  return DownBlock2D(
42
  num_layers=num_layers,
@@ -64,7 +65,8 @@ def get_down_block(
64
  )
65
  elif down_block_type == "CrossAttnDownBlock2D":
66
  if cross_attention_dim is None:
67
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
 
68
  return CrossAttnDownBlock2D(
69
  num_layers=num_layers,
70
  in_channels=in_channels,
@@ -147,7 +149,8 @@ def get_up_block(
147
  use_linear_projection=False,
148
  only_cross_attention=False,
149
  ):
150
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
 
151
  if up_block_type == "UpBlock2D":
152
  return UpBlock2D(
153
  num_layers=num_layers,
@@ -162,7 +165,8 @@ def get_up_block(
162
  )
163
  elif up_block_type == "CrossAttnUpBlock2D":
164
  if cross_attention_dim is None:
165
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
 
166
  return CrossAttnUpBlock2D(
167
  num_layers=num_layers,
168
  in_channels=in_channels,
@@ -258,7 +262,8 @@ class UNetMidBlock2D(nn.Module):
258
  super().__init__()
259
 
260
  self.attention_type = attention_type
261
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
262
 
263
  # there is always at least one resnet
264
  resnets = [
@@ -312,7 +317,7 @@ class UNetMidBlock2D(nn.Module):
312
  hidden_states = attn(hidden_states)
313
  else:
314
  hidden_states = attn(hidden_states, encoder_states)
315
- hidden_states = resnet(hidden_states, temb)
316
 
317
  return hidden_states
318
 
@@ -340,7 +345,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
340
 
341
  self.attention_type = attention_type
342
  self.attn_num_head_channels = attn_num_head_channels
343
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
344
 
345
  # there is always at least one resnet
346
  resnets = [
@@ -420,15 +426,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
420
 
421
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
422
  for attn in self.attentions:
423
- attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
424
 
425
  def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
426
  text_format_dict={}):
427
- hidden_states = self.resnets[0](hidden_states, temb)
428
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
429
- hidden_states = attn(hidden_states, encoder_hidden_states,
430
  text_format_dict).sample
431
- hidden_states = resnet(hidden_states, temb)
432
 
433
  return hidden_states
434
 
@@ -502,7 +509,7 @@ class AttnDownBlock2D(nn.Module):
502
  output_states = ()
503
 
504
  for resnet, attn in zip(self.resnets, self.attentions):
505
- hidden_states = resnet(hidden_states, temb)
506
  hidden_states = attn(hidden_states)
507
  output_states += (hidden_states,)
508
 
@@ -620,7 +627,8 @@ class CrossAttnDownBlock2D(nn.Module):
620
 
621
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
622
  for attn in self.attentions:
623
- attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
624
 
625
  def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
626
  text_format_dict={}):
@@ -638,13 +646,15 @@ class CrossAttnDownBlock2D(nn.Module):
638
 
639
  return custom_forward
640
 
641
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
642
  hidden_states = torch.utils.checkpoint.checkpoint(
643
- create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
644
- text_format_dict
 
 
 
645
  )[0]
646
  else:
647
- hidden_states = resnet(hidden_states, temb)
648
  hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
649
  text_format_dict=text_format_dict).sample
650
 
@@ -723,9 +733,10 @@ class DownBlock2D(nn.Module):
723
 
724
  return custom_forward
725
 
726
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
 
727
  else:
728
- hidden_states = resnet(hidden_states, temb)
729
 
730
  output_states += (hidden_states,)
731
 
@@ -789,7 +800,7 @@ class DownEncoderBlock2D(nn.Module):
789
 
790
  def forward(self, hidden_states):
791
  for resnet in self.resnets:
792
- hidden_states = resnet(hidden_states, temb=None)
793
 
794
  if self.downsamplers is not None:
795
  for downsampler in self.downsamplers:
@@ -861,7 +872,7 @@ class AttnDownEncoderBlock2D(nn.Module):
861
 
862
  def forward(self, hidden_states):
863
  for resnet, attn in zip(self.resnets, self.attentions):
864
- hidden_states = resnet(hidden_states, temb=None)
865
  hidden_states = attn(hidden_states)
866
 
867
  if self.downsamplers is not None:
@@ -937,8 +948,10 @@ class AttnSkipDownBlock2D(nn.Module):
937
  down=True,
938
  kernel="fir",
939
  )
940
- self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
941
- self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
 
 
942
  else:
943
  self.resnet_down = None
944
  self.downsamplers = None
@@ -948,7 +961,7 @@ class AttnSkipDownBlock2D(nn.Module):
948
  output_states = ()
949
 
950
  for resnet, attn in zip(self.resnets, self.attentions):
951
- hidden_states = resnet(hidden_states, temb)
952
  hidden_states = attn(hidden_states)
953
  output_states += (hidden_states,)
954
 
@@ -1017,8 +1030,10 @@ class SkipDownBlock2D(nn.Module):
1017
  down=True,
1018
  kernel="fir",
1019
  )
1020
- self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
1021
- self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
 
 
1022
  else:
1023
  self.resnet_down = None
1024
  self.downsamplers = None
@@ -1028,7 +1043,7 @@ class SkipDownBlock2D(nn.Module):
1028
  output_states = ()
1029
 
1030
  for resnet in self.resnets:
1031
- hidden_states = resnet(hidden_states, temb)
1032
  output_states += (hidden_states,)
1033
 
1034
  if self.downsamplers is not None:
@@ -1069,7 +1084,8 @@ class AttnUpBlock2D(nn.Module):
1069
  self.attention_type = attention_type
1070
 
1071
  for i in range(num_layers):
1072
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
 
1073
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1074
 
1075
  resnets.append(
@@ -1100,7 +1116,8 @@ class AttnUpBlock2D(nn.Module):
1100
  self.resnets = nn.ModuleList(resnets)
1101
 
1102
  if add_upsample:
1103
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
 
1104
  else:
1105
  self.upsamplers = None
1106
 
@@ -1109,9 +1126,10 @@ class AttnUpBlock2D(nn.Module):
1109
  # pop res hidden states
1110
  res_hidden_states = res_hidden_states_tuple[-1]
1111
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1112
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
1113
 
1114
- hidden_states = resnet(hidden_states, temb)
1115
  hidden_states = attn(hidden_states)
1116
 
1117
  if self.upsamplers is not None:
@@ -1152,7 +1170,8 @@ class CrossAttnUpBlock2D(nn.Module):
1152
  self.attn_num_head_channels = attn_num_head_channels
1153
 
1154
  for i in range(num_layers):
1155
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
 
1156
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1157
 
1158
  resnets.append(
@@ -1197,7 +1216,8 @@ class CrossAttnUpBlock2D(nn.Module):
1197
  self.resnets = nn.ModuleList(resnets)
1198
 
1199
  if add_upsample:
1200
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
 
1201
  else:
1202
  self.upsamplers = None
1203
 
@@ -1224,7 +1244,8 @@ class CrossAttnUpBlock2D(nn.Module):
1224
 
1225
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1226
  for attn in self.attentions:
1227
- attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
1228
 
1229
  def forward(
1230
  self,
@@ -1239,7 +1260,8 @@ class CrossAttnUpBlock2D(nn.Module):
1239
  # pop res hidden states
1240
  res_hidden_states = res_hidden_states_tuple[-1]
1241
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1242
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
1243
 
1244
  if self.training and self.gradient_checkpointing:
1245
 
@@ -1252,13 +1274,15 @@ class CrossAttnUpBlock2D(nn.Module):
1252
 
1253
  return custom_forward
1254
 
1255
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1256
  hidden_states = torch.utils.checkpoint.checkpoint(
1257
- create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
1258
- text_format_dict
 
 
 
1259
  )[0]
1260
  else:
1261
- hidden_states = resnet(hidden_states, temb)
1262
  hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
1263
  text_format_dict=text_format_dict).sample
1264
 
@@ -1290,7 +1314,8 @@ class UpBlock2D(nn.Module):
1290
  resnets = []
1291
 
1292
  for i in range(num_layers):
1293
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
 
1294
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1295
 
1296
  resnets.append(
@@ -1311,7 +1336,8 @@ class UpBlock2D(nn.Module):
1311
  self.resnets = nn.ModuleList(resnets)
1312
 
1313
  if add_upsample:
1314
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
 
1315
  else:
1316
  self.upsamplers = None
1317
 
@@ -1322,7 +1348,8 @@ class UpBlock2D(nn.Module):
1322
  # pop res hidden states
1323
  res_hidden_states = res_hidden_states_tuple[-1]
1324
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1325
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
1326
 
1327
  if self.training and self.gradient_checkpointing:
1328
 
@@ -1332,9 +1359,10 @@ class UpBlock2D(nn.Module):
1332
 
1333
  return custom_forward
1334
 
1335
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
 
1336
  else:
1337
- hidden_states = resnet(hidden_states, temb)
1338
 
1339
  if self.upsamplers is not None:
1340
  for upsampler in self.upsamplers:
@@ -1382,13 +1410,14 @@ class UpDecoderBlock2D(nn.Module):
1382
  self.resnets = nn.ModuleList(resnets)
1383
 
1384
  if add_upsample:
1385
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
 
1386
  else:
1387
  self.upsamplers = None
1388
 
1389
  def forward(self, hidden_states):
1390
  for resnet in self.resnets:
1391
- hidden_states = resnet(hidden_states, temb=None)
1392
 
1393
  if self.upsamplers is not None:
1394
  for upsampler in self.upsamplers:
@@ -1448,13 +1477,14 @@ class AttnUpDecoderBlock2D(nn.Module):
1448
  self.resnets = nn.ModuleList(resnets)
1449
 
1450
  if add_upsample:
1451
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
 
1452
  else:
1453
  self.upsamplers = None
1454
 
1455
  def forward(self, hidden_states):
1456
  for resnet, attn in zip(self.resnets, self.attentions):
1457
- hidden_states = resnet(hidden_states, temb=None)
1458
  hidden_states = attn(hidden_states)
1459
 
1460
  if self.upsamplers is not None:
@@ -1490,7 +1520,8 @@ class AttnSkipUpBlock2D(nn.Module):
1490
  self.attention_type = attention_type
1491
 
1492
  for i in range(num_layers):
1493
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
 
1494
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1495
 
1496
  self.resnets.append(
@@ -1499,7 +1530,8 @@ class AttnSkipUpBlock2D(nn.Module):
1499
  out_channels=out_channels,
1500
  temb_channels=temb_channels,
1501
  eps=resnet_eps,
1502
- groups=min(resnet_in_channels + res_skip_channels // 4, 32),
 
1503
  groups_out=min(out_channels // 4, 32),
1504
  dropout=dropout,
1505
  time_embedding_norm=resnet_time_scale_shift,
@@ -1536,7 +1568,8 @@ class AttnSkipUpBlock2D(nn.Module):
1536
  up=True,
1537
  kernel="fir",
1538
  )
1539
- self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 
1540
  self.skip_norm = torch.nn.GroupNorm(
1541
  num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1542
  )
@@ -1552,9 +1585,10 @@ class AttnSkipUpBlock2D(nn.Module):
1552
  # pop res hidden states
1553
  res_hidden_states = res_hidden_states_tuple[-1]
1554
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1555
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
1556
 
1557
- hidden_states = resnet(hidden_states, temb)
1558
 
1559
  hidden_states = self.attentions[0](hidden_states)
1560
 
@@ -1596,7 +1630,8 @@ class SkipUpBlock2D(nn.Module):
1596
  self.resnets = nn.ModuleList([])
1597
 
1598
  for i in range(num_layers):
1599
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
 
1600
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1601
 
1602
  self.resnets.append(
@@ -1605,7 +1640,8 @@ class SkipUpBlock2D(nn.Module):
1605
  out_channels=out_channels,
1606
  temb_channels=temb_channels,
1607
  eps=resnet_eps,
1608
- groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
 
1609
  groups_out=min(out_channels // 4, 32),
1610
  dropout=dropout,
1611
  time_embedding_norm=resnet_time_scale_shift,
@@ -1633,7 +1669,8 @@ class SkipUpBlock2D(nn.Module):
1633
  up=True,
1634
  kernel="fir",
1635
  )
1636
- self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 
1637
  self.skip_norm = torch.nn.GroupNorm(
1638
  num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1639
  )
@@ -1649,9 +1686,10 @@ class SkipUpBlock2D(nn.Module):
1649
  # pop res hidden states
1650
  res_hidden_states = res_hidden_states_tuple[-1]
1651
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1652
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
1653
 
1654
- hidden_states = resnet(hidden_states, temb)
1655
 
1656
  if skip_sample is not None:
1657
  skip_sample = self.upsampler(skip_sample)
@@ -1668,3 +1706,150 @@ class SkipUpBlock2D(nn.Module):
1668
  hidden_states = self.resnet_up(hidden_states, temb)
1669
 
1670
  return hidden_states, skip_sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from torch import nn
17
 
18
  from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, Upsample2D
20
 
21
 
22
  def get_down_block(
 
36
  use_linear_projection=False,
37
  only_cross_attention=False,
38
  ):
39
+ down_block_type = down_block_type[7:] if down_block_type.startswith(
40
+ "UNetRes") else down_block_type
41
  if down_block_type == "DownBlock2D":
42
  return DownBlock2D(
43
  num_layers=num_layers,
 
65
  )
66
  elif down_block_type == "CrossAttnDownBlock2D":
67
  if cross_attention_dim is None:
68
+ raise ValueError(
69
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D")
70
  return CrossAttnDownBlock2D(
71
  num_layers=num_layers,
72
  in_channels=in_channels,
 
149
  use_linear_projection=False,
150
  only_cross_attention=False,
151
  ):
152
+ up_block_type = up_block_type[7:] if up_block_type.startswith(
153
+ "UNetRes") else up_block_type
154
  if up_block_type == "UpBlock2D":
155
  return UpBlock2D(
156
  num_layers=num_layers,
 
165
  )
166
  elif up_block_type == "CrossAttnUpBlock2D":
167
  if cross_attention_dim is None:
168
+ raise ValueError(
169
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D")
170
  return CrossAttnUpBlock2D(
171
  num_layers=num_layers,
172
  in_channels=in_channels,
 
262
  super().__init__()
263
 
264
  self.attention_type = attention_type
265
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
266
+ in_channels // 4, 32)
267
 
268
  # there is always at least one resnet
269
  resnets = [
 
317
  hidden_states = attn(hidden_states)
318
  else:
319
  hidden_states = attn(hidden_states, encoder_states)
320
+ hidden_states, _ = resnet(hidden_states, temb)
321
 
322
  return hidden_states
323
 
 
345
 
346
  self.attention_type = attention_type
347
  self.attn_num_head_channels = attn_num_head_channels
348
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
349
+ in_channels // 4, 32)
350
 
351
  # there is always at least one resnet
352
  resnets = [
 
426
 
427
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
428
  for attn in self.attentions:
429
+ attn._set_use_memory_efficient_attention_xformers(
430
+ use_memory_efficient_attention_xformers)
431
 
432
  def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
433
  text_format_dict={}):
434
+ hidden_states, _ = self.resnets[0](hidden_states, temb)
435
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
436
+ hidden_states = attn(hidden_states, encoder_hidden_states,
437
  text_format_dict).sample
438
+ hidden_states, _ = resnet(hidden_states, temb)
439
 
440
  return hidden_states
441
 
 
509
  output_states = ()
510
 
511
  for resnet, attn in zip(self.resnets, self.attentions):
512
+ hidden_states, _ = resnet(hidden_states, temb)
513
  hidden_states = attn(hidden_states)
514
  output_states += (hidden_states,)
515
 
 
627
 
628
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
629
  for attn in self.attentions:
630
+ attn._set_use_memory_efficient_attention_xformers(
631
+ use_memory_efficient_attention_xformers)
632
 
633
  def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
634
  text_format_dict={}):
 
646
 
647
  return custom_forward
648
 
 
649
  hidden_states = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(resnet), hidden_states, temb)
651
+ hidden_states = torch.utils.checkpoint.checkpoint(
652
+ create_custom_forward(
653
+ attn, return_dict=False), hidden_states, encoder_hidden_states,
654
+ text_format_dict
655
  )[0]
656
  else:
657
+ hidden_states, _ = resnet(hidden_states, temb)
658
  hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
659
  text_format_dict=text_format_dict).sample
660
 
 
733
 
734
  return custom_forward
735
 
736
+ hidden_states = torch.utils.checkpoint.checkpoint(
737
+ create_custom_forward(resnet), hidden_states, temb)
738
  else:
739
+ hidden_states, _ = resnet(hidden_states, temb)
740
 
741
  output_states += (hidden_states,)
742
 
 
800
 
801
  def forward(self, hidden_states):
802
  for resnet in self.resnets:
803
+ hidden_states, _ = resnet(hidden_states, temb=None)
804
 
805
  if self.downsamplers is not None:
806
  for downsampler in self.downsamplers:
 
872
 
873
  def forward(self, hidden_states):
874
  for resnet, attn in zip(self.resnets, self.attentions):
875
+ hidden_states, _ = resnet(hidden_states, temb=None)
876
  hidden_states = attn(hidden_states)
877
 
878
  if self.downsamplers is not None:
 
948
  down=True,
949
  kernel="fir",
950
  )
951
+ self.downsamplers = nn.ModuleList(
952
+ [FirDownsample2D(out_channels, out_channels=out_channels)])
953
+ self.skip_conv = nn.Conv2d(
954
+ 3, out_channels, kernel_size=(1, 1), stride=(1, 1))
955
  else:
956
  self.resnet_down = None
957
  self.downsamplers = None
 
961
  output_states = ()
962
 
963
  for resnet, attn in zip(self.resnets, self.attentions):
964
+ hidden_states, _ = resnet(hidden_states, temb)
965
  hidden_states = attn(hidden_states)
966
  output_states += (hidden_states,)
967
 
 
1030
  down=True,
1031
  kernel="fir",
1032
  )
1033
+ self.downsamplers = nn.ModuleList(
1034
+ [FirDownsample2D(out_channels, out_channels=out_channels)])
1035
+ self.skip_conv = nn.Conv2d(
1036
+ 3, out_channels, kernel_size=(1, 1), stride=(1, 1))
1037
  else:
1038
  self.resnet_down = None
1039
  self.downsamplers = None
 
1043
  output_states = ()
1044
 
1045
  for resnet in self.resnets:
1046
+ hidden_states, _ = resnet(hidden_states, temb)
1047
  output_states += (hidden_states,)
1048
 
1049
  if self.downsamplers is not None:
 
1084
  self.attention_type = attention_type
1085
 
1086
  for i in range(num_layers):
1087
+ res_skip_channels = in_channels if (
1088
+ i == num_layers - 1) else out_channels
1089
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1090
 
1091
  resnets.append(
 
1116
  self.resnets = nn.ModuleList(resnets)
1117
 
1118
  if add_upsample:
1119
+ self.upsamplers = nn.ModuleList(
1120
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1121
  else:
1122
  self.upsamplers = None
1123
 
 
1126
  # pop res hidden states
1127
  res_hidden_states = res_hidden_states_tuple[-1]
1128
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1129
+ hidden_states = torch.cat(
1130
+ [hidden_states, res_hidden_states], dim=1)
1131
 
1132
+ hidden_states, _ = resnet(hidden_states, temb)
1133
  hidden_states = attn(hidden_states)
1134
 
1135
  if self.upsamplers is not None:
 
1170
  self.attn_num_head_channels = attn_num_head_channels
1171
 
1172
  for i in range(num_layers):
1173
+ res_skip_channels = in_channels if (
1174
+ i == num_layers - 1) else out_channels
1175
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1176
 
1177
  resnets.append(
 
1216
  self.resnets = nn.ModuleList(resnets)
1217
 
1218
  if add_upsample:
1219
+ self.upsamplers = nn.ModuleList(
1220
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1221
  else:
1222
  self.upsamplers = None
1223
 
 
1244
 
1245
  def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1246
  for attn in self.attentions:
1247
+ attn._set_use_memory_efficient_attention_xformers(
1248
+ use_memory_efficient_attention_xformers)
1249
 
1250
  def forward(
1251
  self,
 
1260
  # pop res hidden states
1261
  res_hidden_states = res_hidden_states_tuple[-1]
1262
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1263
+ hidden_states = torch.cat(
1264
+ [hidden_states, res_hidden_states], dim=1)
1265
 
1266
  if self.training and self.gradient_checkpointing:
1267
 
 
1274
 
1275
  return custom_forward
1276
 
 
1277
  hidden_states = torch.utils.checkpoint.checkpoint(
1278
+ create_custom_forward(resnet), hidden_states, temb)
1279
+ hidden_states = torch.utils.checkpoint.checkpoint(
1280
+ create_custom_forward(
1281
+ attn, return_dict=False), hidden_states, encoder_hidden_states,
1282
+ text_format_dict
1283
  )[0]
1284
  else:
1285
+ hidden_states, _ = resnet(hidden_states, temb)
1286
  hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
1287
  text_format_dict=text_format_dict).sample
1288
 
 
1314
  resnets = []
1315
 
1316
  for i in range(num_layers):
1317
+ res_skip_channels = in_channels if (
1318
+ i == num_layers - 1) else out_channels
1319
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1320
 
1321
  resnets.append(
 
1336
  self.resnets = nn.ModuleList(resnets)
1337
 
1338
  if add_upsample:
1339
+ self.upsamplers = nn.ModuleList(
1340
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1341
  else:
1342
  self.upsamplers = None
1343
 
 
1348
  # pop res hidden states
1349
  res_hidden_states = res_hidden_states_tuple[-1]
1350
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1351
+ hidden_states = torch.cat(
1352
+ [hidden_states, res_hidden_states], dim=1)
1353
 
1354
  if self.training and self.gradient_checkpointing:
1355
 
 
1359
 
1360
  return custom_forward
1361
 
1362
+ hidden_states = torch.utils.checkpoint.checkpoint(
1363
+ create_custom_forward(resnet), hidden_states, temb)
1364
  else:
1365
+ hidden_states, _ = resnet(hidden_states, temb)
1366
 
1367
  if self.upsamplers is not None:
1368
  for upsampler in self.upsamplers:
 
1410
  self.resnets = nn.ModuleList(resnets)
1411
 
1412
  if add_upsample:
1413
+ self.upsamplers = nn.ModuleList(
1414
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1415
  else:
1416
  self.upsamplers = None
1417
 
1418
  def forward(self, hidden_states):
1419
  for resnet in self.resnets:
1420
+ hidden_states, _ = resnet(hidden_states, temb=None)
1421
 
1422
  if self.upsamplers is not None:
1423
  for upsampler in self.upsamplers:
 
1477
  self.resnets = nn.ModuleList(resnets)
1478
 
1479
  if add_upsample:
1480
+ self.upsamplers = nn.ModuleList(
1481
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1482
  else:
1483
  self.upsamplers = None
1484
 
1485
  def forward(self, hidden_states):
1486
  for resnet, attn in zip(self.resnets, self.attentions):
1487
+ hidden_states, _ = resnet(hidden_states, temb=None)
1488
  hidden_states = attn(hidden_states)
1489
 
1490
  if self.upsamplers is not None:
 
1520
  self.attention_type = attention_type
1521
 
1522
  for i in range(num_layers):
1523
+ res_skip_channels = in_channels if (
1524
+ i == num_layers - 1) else out_channels
1525
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1526
 
1527
  self.resnets.append(
 
1530
  out_channels=out_channels,
1531
  temb_channels=temb_channels,
1532
  eps=resnet_eps,
1533
+ groups=min(resnet_in_channels +
1534
+ res_skip_channels // 4, 32),
1535
  groups_out=min(out_channels // 4, 32),
1536
  dropout=dropout,
1537
  time_embedding_norm=resnet_time_scale_shift,
 
1568
  up=True,
1569
  kernel="fir",
1570
  )
1571
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
1572
+ 3, 3), stride=(1, 1), padding=(1, 1))
1573
  self.skip_norm = torch.nn.GroupNorm(
1574
  num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1575
  )
 
1585
  # pop res hidden states
1586
  res_hidden_states = res_hidden_states_tuple[-1]
1587
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1588
+ hidden_states = torch.cat(
1589
+ [hidden_states, res_hidden_states], dim=1)
1590
 
1591
+ hidden_states, _ = resnet(hidden_states, temb)
1592
 
1593
  hidden_states = self.attentions[0](hidden_states)
1594
 
 
1630
  self.resnets = nn.ModuleList([])
1631
 
1632
  for i in range(num_layers):
1633
+ res_skip_channels = in_channels if (
1634
+ i == num_layers - 1) else out_channels
1635
  resnet_in_channels = prev_output_channel if i == 0 else out_channels
1636
 
1637
  self.resnets.append(
 
1640
  out_channels=out_channels,
1641
  temb_channels=temb_channels,
1642
  eps=resnet_eps,
1643
+ groups=min(
1644
+ (resnet_in_channels + res_skip_channels) // 4, 32),
1645
  groups_out=min(out_channels // 4, 32),
1646
  dropout=dropout,
1647
  time_embedding_norm=resnet_time_scale_shift,
 
1669
  up=True,
1670
  kernel="fir",
1671
  )
1672
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
1673
+ 3, 3), stride=(1, 1), padding=(1, 1))
1674
  self.skip_norm = torch.nn.GroupNorm(
1675
  num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1676
  )
 
1686
  # pop res hidden states
1687
  res_hidden_states = res_hidden_states_tuple[-1]
1688
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1689
+ hidden_states = torch.cat(
1690
+ [hidden_states, res_hidden_states], dim=1)
1691
 
1692
+ hidden_states, _ = resnet(hidden_states, temb)
1693
 
1694
  if skip_sample is not None:
1695
  skip_sample = self.upsampler(skip_sample)
 
1706
  hidden_states = self.resnet_up(hidden_states, temb)
1707
 
1708
  return hidden_states, skip_sample
1709
+
1710
+
1711
+ class ResnetBlock2D(nn.Module):
1712
+ def __init__(
1713
+ self,
1714
+ *,
1715
+ in_channels,
1716
+ out_channels=None,
1717
+ conv_shortcut=False,
1718
+ dropout=0.0,
1719
+ temb_channels=512,
1720
+ groups=32,
1721
+ groups_out=None,
1722
+ pre_norm=True,
1723
+ eps=1e-6,
1724
+ non_linearity="swish",
1725
+ time_embedding_norm="default",
1726
+ kernel=None,
1727
+ output_scale_factor=1.0,
1728
+ use_in_shortcut=None,
1729
+ up=False,
1730
+ down=False,
1731
+ ):
1732
+ super().__init__()
1733
+ self.pre_norm = pre_norm
1734
+ self.pre_norm = True
1735
+ self.in_channels = in_channels
1736
+ out_channels = in_channels if out_channels is None else out_channels
1737
+ self.out_channels = out_channels
1738
+ self.use_conv_shortcut = conv_shortcut
1739
+ self.time_embedding_norm = time_embedding_norm
1740
+ self.up = up
1741
+ self.down = down
1742
+ self.output_scale_factor = output_scale_factor
1743
+
1744
+ if groups_out is None:
1745
+ groups_out = groups
1746
+
1747
+ self.norm1 = torch.nn.GroupNorm(
1748
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
1749
+
1750
+ self.conv1 = torch.nn.Conv2d(
1751
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
1752
+
1753
+ if temb_channels is not None:
1754
+ if self.time_embedding_norm == "default":
1755
+ time_emb_proj_out_channels = out_channels
1756
+ elif self.time_embedding_norm == "scale_shift":
1757
+ time_emb_proj_out_channels = out_channels * 2
1758
+ else:
1759
+ raise ValueError(
1760
+ f"unknown time_embedding_norm : {self.time_embedding_norm} ")
1761
+
1762
+ self.time_emb_proj = torch.nn.Linear(
1763
+ temb_channels, time_emb_proj_out_channels)
1764
+ else:
1765
+ self.time_emb_proj = None
1766
+
1767
+ self.norm2 = torch.nn.GroupNorm(
1768
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
1769
+ self.dropout = torch.nn.Dropout(dropout)
1770
+ self.conv2 = torch.nn.Conv2d(
1771
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
1772
+
1773
+ if non_linearity == "swish":
1774
+ self.nonlinearity = lambda x: F.silu(x)
1775
+ elif non_linearity == "mish":
1776
+ self.nonlinearity = Mish()
1777
+ elif non_linearity == "silu":
1778
+ self.nonlinearity = nn.SiLU()
1779
+
1780
+ self.upsample = self.downsample = None
1781
+ if self.up:
1782
+ if kernel == "fir":
1783
+ fir_kernel = (1, 3, 3, 1)
1784
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
1785
+ elif kernel == "sde_vp":
1786
+ self.upsample = partial(
1787
+ F.interpolate, scale_factor=2.0, mode="nearest")
1788
+ else:
1789
+ self.upsample = Upsample2D(in_channels, use_conv=False)
1790
+ elif self.down:
1791
+ if kernel == "fir":
1792
+ fir_kernel = (1, 3, 3, 1)
1793
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
1794
+ elif kernel == "sde_vp":
1795
+ self.downsample = partial(
1796
+ F.avg_pool2d, kernel_size=2, stride=2)
1797
+ else:
1798
+ self.downsample = Downsample2D(
1799
+ in_channels, use_conv=False, padding=1, name="op")
1800
+
1801
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
1802
+
1803
+ self.conv_shortcut = None
1804
+ if self.use_in_shortcut:
1805
+ self.conv_shortcut = torch.nn.Conv2d(
1806
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
1807
+
1808
+ def forward(self, input_tensor, temb, inject_states=None):
1809
+ hidden_states = input_tensor
1810
+
1811
+ hidden_states = self.norm1(hidden_states)
1812
+ hidden_states = self.nonlinearity(hidden_states)
1813
+
1814
+ if self.upsample is not None:
1815
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1816
+ if hidden_states.shape[0] >= 64:
1817
+ input_tensor = input_tensor.contiguous()
1818
+ hidden_states = hidden_states.contiguous()
1819
+ input_tensor = self.upsample(input_tensor)
1820
+ hidden_states = self.upsample(hidden_states)
1821
+ elif self.downsample is not None:
1822
+ input_tensor = self.downsample(input_tensor)
1823
+ hidden_states = self.downsample(hidden_states)
1824
+
1825
+ hidden_states = self.conv1(hidden_states)
1826
+
1827
+ if temb is not None:
1828
+ temb = self.time_emb_proj(self.nonlinearity(temb))[
1829
+ :, :, None, None]
1830
+
1831
+ if temb is not None and self.time_embedding_norm == "default":
1832
+ hidden_states = hidden_states + temb
1833
+
1834
+ hidden_states = self.norm2(hidden_states)
1835
+
1836
+ if temb is not None and self.time_embedding_norm == "scale_shift":
1837
+ scale, shift = torch.chunk(temb, 2, dim=1)
1838
+ hidden_states = hidden_states * (1 + scale) + shift
1839
+
1840
+ hidden_states = self.nonlinearity(hidden_states)
1841
+
1842
+ hidden_states = self.dropout(hidden_states)
1843
+ hidden_states = self.conv2(hidden_states)
1844
+
1845
+ if self.conv_shortcut is not None:
1846
+ input_tensor = self.conv_shortcut(input_tensor)
1847
+
1848
+ if inject_states is not None:
1849
+ output_tensor = (input_tensor + inject_states) / \
1850
+ self.output_scale_factor
1851
+ else:
1852
+ output_tensor = (input_tensor + hidden_states) / \
1853
+ self.output_scale_factor
1854
+
1855
+ return output_tensor, hidden_states
utils/attention_utils.py CHANGED
@@ -6,7 +6,46 @@ import seaborn as sns
6
  import torch
7
  import torchvision
8
 
9
- from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def split_attention_maps_over_steps(attention_maps):
@@ -37,7 +76,7 @@ def split_attention_maps_over_steps(attention_maps):
37
 
38
  def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
39
  atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
40
- for i, (attn_map, obj_token) in enumerate(zip(atten_map_list, obj_tokens)):
41
  n_obj = len(attn_map)
42
  plt.figure()
43
  plt.clf()
@@ -63,6 +102,7 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
63
  cmap=cmap, vmin=vmin, vmax=vmax
64
  )
65
  axs[tid].set_axis_off()
 
66
  if tokens_vis is not None:
67
  if tid == n_obj-1:
68
  axs_xlabel = 'other tokens'
@@ -79,13 +119,14 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
79
  canvas = fig.canvas
80
  canvas.draw()
81
  width, height = canvas.get_width_height()
82
- img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape((height, width, 3))
 
83
 
84
  fig.tight_layout()
85
  return img
86
 
87
 
88
- def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
89
  r"""Function to visualize attention maps.
90
  Args:
91
  save_dir (str): Path to save attention maps
@@ -98,25 +139,6 @@ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0,
98
  attention_maps
99
  )
100
 
101
- selected_layers = [
102
- # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
103
- # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
104
- 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
105
- # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
106
- 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
107
- 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
108
- 'mid_block.attentions.0.transformer_blocks.0.attn2',
109
- 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
110
- 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
111
- 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
112
- # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
113
- 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
114
- # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
115
- # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
116
- # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
117
- # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
118
- ]
119
-
120
  nsteps = len(attention_maps_cond)
121
  hw_ori = width * height
122
 
@@ -128,7 +150,7 @@ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0,
128
  attention_maps_cur = attention_maps_cond[step_num]
129
 
130
  for layer in attention_maps_cur.keys():
131
- if step_num < 10 or layer not in selected_layers:
132
  continue
133
 
134
  attention_ind = attention_maps_cur[layer].cpu()
@@ -179,7 +201,107 @@ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0,
179
  attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
180
 
181
  token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
182
- obj_tokens, save_dir, seed, tokens_vis)
183
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
184
  [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
185
  return attention_maps_averaged_normalized, token_maps_vis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import torch
7
  import torchvision
8
 
9
+ from sklearn.cluster import KMeans
10
+
11
+ SelfAttentionLayers = [
12
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
13
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
14
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
15
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
16
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
17
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
18
+ 'mid_block.attentions.0.transformer_blocks.0.attn1',
19
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
20
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
21
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
22
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
23
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
24
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
25
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
26
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
27
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
28
+ ]
29
+
30
+
31
+ CrossAttentionLayers = [
32
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
33
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
34
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
35
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
36
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
37
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
38
+ 'mid_block.attentions.0.transformer_blocks.0.attn2',
39
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
40
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
41
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
42
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
43
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
44
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
45
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
46
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
47
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
48
+ ]
49
 
50
 
51
  def split_attention_maps_over_steps(attention_maps):
 
76
 
77
  def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
78
  atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
79
+ for i, attn_map in enumerate(atten_map_list):
80
  n_obj = len(attn_map)
81
  plt.figure()
82
  plt.clf()
 
102
  cmap=cmap, vmin=vmin, vmax=vmax
103
  )
104
  axs[tid].set_axis_off()
105
+
106
  if tokens_vis is not None:
107
  if tid == n_obj-1:
108
  axs_xlabel = 'other tokens'
 
119
  canvas = fig.canvas
120
  canvas.draw()
121
  width, height = canvas.get_width_height()
122
+ img = np.frombuffer(canvas.tostring_rgb(),
123
+ dtype='uint8').reshape((height, width, 3))
124
 
125
  fig.tight_layout()
126
  return img
127
 
128
 
129
+ def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
130
  r"""Function to visualize attention maps.
131
  Args:
132
  save_dir (str): Path to save attention maps
 
139
  attention_maps
140
  )
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  nsteps = len(attention_maps_cond)
143
  hw_ori = width * height
144
 
 
150
  attention_maps_cur = attention_maps_cond[step_num]
151
 
152
  for layer in attention_maps_cur.keys():
153
+ if step_num < 10 or layer not in CrossAttentionLayers:
154
  continue
155
 
156
  attention_ind = attention_maps_cur[layer].cpu()
 
201
  attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
202
 
203
  token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
204
+ obj_tokens, save_dir, seed, tokens_vis)
205
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
206
  [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
207
  return attention_maps_averaged_normalized, token_maps_vis
208
+
209
+
210
+ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
211
+ preprocess=False, segment_threshold=0.30, num_segments=9, return_vis=False):
212
+ r"""Function to visualize attention maps.
213
+ Args:
214
+ save_dir (str): Path to save attention maps
215
+ batch_size (int): Batch size
216
+ sampler_order (int): Sampler order
217
+ """
218
+
219
+ # create the segmentation mask using self-attention maps
220
+ resolution = 32
221
+ attn_maps_1024 = {8: [], 16: [], 32: []}
222
+ for attn_map in selfattn_maps.values():
223
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
224
+ attn_map = attn_map.reshape(
225
+ 1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
226
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
227
+ mode='bicubic', antialias=True)
228
+ attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
229
+ 1, resolution**2, resolution_map**2))
230
+ attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
231
+ for v in attn_maps_1024.values()], -1).numpy()
232
+ kmeans = KMeans(n_clusters=num_segments,
233
+ n_init=10).fit(attn_maps_1024)
234
+ clusters = kmeans.labels_
235
+ clusters = clusters.reshape(resolution, resolution)
236
+ fig = plt.figure()
237
+ plt.imshow(clusters)
238
+ plt.axis('off')
239
+ plt.savefig(os.path.join(save_dir, 'segmentation_k%d.jpg' % (num_segments)),
240
+ bbox_inches='tight', pad_inches=0)
241
+ if return_vis:
242
+ canvas = fig.canvas
243
+ canvas.draw()
244
+ cav_width, cav_height = canvas.get_width_height()
245
+ segments_vis = np.frombuffer(canvas.tostring_rgb(),
246
+ dtype='uint8').reshape((cav_height, cav_width, 3))
247
+
248
+ plt.close()
249
+
250
+ # label the segmentation mask using cross-attention maps
251
+ cross_attn_maps_1024 = []
252
+ for attn_map in crossattn_maps.values():
253
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
254
+ attn_map = attn_map.reshape(
255
+ 1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2])
256
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
257
+ mode='bicubic', antialias=True)
258
+ cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
259
+
260
+ cross_attn_maps_1024 = torch.cat(
261
+ cross_attn_maps_1024).mean(0).cpu().numpy()
262
+ normalized_span_maps = []
263
+ for token_ids in obj_tokens:
264
+ span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
265
+ normalized_span_map = np.zeros_like(span_token_maps)
266
+ for i in range(span_token_maps.shape[-1]):
267
+ curr_noun_map = span_token_maps[:, :, i]
268
+ normalized_span_map[:, :, i] = (
269
+ curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
270
+ normalized_span_maps.append(normalized_span_map)
271
+ foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
272
+ ) for normalized_span_map in normalized_span_maps]
273
+ background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
274
+ for c in range(num_segments):
275
+ cluster_mask = np.zeros_like(clusters)
276
+ cluster_mask[clusters == c] = 1.
277
+ is_foreground = False
278
+ for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
279
+ score_maps = [cluster_mask * normalized_span_map[:, :, i]
280
+ for i in range(len(token_ids))]
281
+ scores = [score_map.sum() / cluster_mask.sum()
282
+ for score_map in score_maps]
283
+ if max(scores) > segment_threshold:
284
+ foreground_nouns_map += cluster_mask
285
+ is_foreground = True
286
+ if not is_foreground:
287
+ background_map += cluster_mask
288
+ foreground_token_maps.append(background_map)
289
+
290
+ # resize the token maps and visualization
291
+ resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
292
+ 0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
293
+
294
+ resized_token_maps = resized_token_maps / \
295
+ (resized_token_maps.sum(0, True)+1e-8)
296
+ resized_token_maps = [token_map.unsqueeze(
297
+ 0) for token_map in resized_token_maps]
298
+ foreground_token_maps = [token_map[None, :, :]
299
+ for token_map in foreground_token_maps]
300
+ token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
301
+ save_dir, kmeans_seed, tokens_vis)
302
+ resized_token_maps = [token_map.unsqueeze(1).repeat(
303
+ [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
304
+ if return_vis:
305
+ return resized_token_maps, segments_vis, token_maps_vis
306
+ else:
307
+ return resized_token_maps
utils/richtext_utils.py CHANGED
@@ -27,7 +27,7 @@ def seed_everything(seed):
27
  torch.cuda.manual_seed(seed)
28
 
29
 
30
- def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
31
  r"""
32
  Covert Hex triplet to RGB triplet.
33
  """
@@ -40,8 +40,8 @@ def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
40
  rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
  if return_nearest_color:
42
  nearest_color = find_nearest_color(rgb)
43
- return rgb.to(device), nearest_color
44
- return rgb.to(device)
45
 
46
 
47
  def find_nearest_color(rgb):
@@ -56,7 +56,7 @@ def find_nearest_color(rgb):
56
  return nearest_color
57
 
58
 
59
- def font2style(font, device='cuda'):
60
  r"""
61
  Convert the font name to the style name.
62
  """
@@ -71,7 +71,7 @@ def font2style(font, device='cuda'):
71
  'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
 
73
 
74
- def parse_json(json_str, device):
75
  r"""
76
  Convert the JSON string to attributes.
77
  """
@@ -121,7 +121,7 @@ def parse_json(json_str, device):
121
  if 'color' in span['attributes']:
122
  use_grad_guidance = True
123
  color_rgb, nearest_color = hex_to_rgb(
124
- span['attributes']['color'], True, device=device)
125
  if prev_color_rgb == color_rgb:
126
  prev_text_prompt = color_text_prompts[-1]
127
  color_text_prompts[-1] = prev_text_prompt + \
@@ -197,8 +197,8 @@ def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes)
197
  word_pos.append(base_tokens.index(size_token)+1)
198
  font_sizes.append(font_size)
199
  if len(word_pos) > 0:
200
- word_pos = torch.LongTensor(word_pos).to(model.device)
201
- font_sizes = torch.FloatTensor(font_sizes).to(model.device)
202
  else:
203
  word_pos = None
204
  font_sizes = None
 
27
  torch.cuda.manual_seed(seed)
28
 
29
 
30
+ def hex_to_rgb(hex_string, return_nearest_color=False):
31
  r"""
32
  Covert Hex triplet to RGB triplet.
33
  """
 
40
  rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
  if return_nearest_color:
42
  nearest_color = find_nearest_color(rgb)
43
+ return rgb.cuda(), nearest_color
44
+ return rgb.cuda()
45
 
46
 
47
  def find_nearest_color(rgb):
 
56
  return nearest_color
57
 
58
 
59
+ def font2style(font):
60
  r"""
61
  Convert the font name to the style name.
62
  """
 
71
  'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
 
73
 
74
+ def parse_json(json_str):
75
  r"""
76
  Convert the JSON string to attributes.
77
  """
 
121
  if 'color' in span['attributes']:
122
  use_grad_guidance = True
123
  color_rgb, nearest_color = hex_to_rgb(
124
+ span['attributes']['color'], True)
125
  if prev_color_rgb == color_rgb:
126
  prev_text_prompt = color_text_prompts[-1]
127
  color_text_prompts[-1] = prev_text_prompt + \
 
197
  word_pos.append(base_tokens.index(size_token)+1)
198
  font_sizes.append(font_size)
199
  if len(word_pos) > 0:
200
+ word_pos = torch.LongTensor(word_pos).cuda()
201
+ font_sizes = torch.FloatTensor(font_sizes).cuda()
202
  else:
203
  word_pos = None
204
  font_sizes = None