songweig committed
Commit ea48617
1 Parent(s): 4d49165

token map v3

Files changed (3)
  1. app.py +71 -69
  2. models/region_diffusion.py +15 -28
  3. utils/attention_utils.py +29 -17
app.py CHANGED
@@ -29,10 +29,10 @@ If you are encountering an error or not achieving your desired outcome, here are
29
 
30
  canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
  get_js_data = """
32
- async (text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, rich_text_input, background_aug) => {
33
  const richEl = document.getElementById("rich-text-root");
34
  const data = richEl? richEl.contentDocument.body._data : {};
35
- return [text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, JSON.stringify(data), background_aug];
36
  }
37
  """
38
  set_js_data = """
@@ -66,27 +66,27 @@ def main():
66
  def generate(
67
  text_input: str,
68
  negative_text: str,
69
- height: int,
70
- width: int,
71
- seed: int,
72
- steps: int,
73
  num_segments: int,
74
  segment_threshold: float,
75
  inject_interval: float,
76
- guidance_weight: float,
 
77
  color_guidance_weight: float,
78
  rich_text_input: str,
79
- background_aug: bool,
80
  ):
81
  run_dir = 'results/'
82
  os.makedirs(run_dir, exist_ok=True)
83
  # Load region diffusion model.
84
- height = int(height)
85
- width = int(width)
86
  steps = 41 if not steps else steps
87
  guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
- text_input = rich_text_input if rich_text_input != '' else text_input
89
- print('text_input', text_input)
90
  if (text_input == '' or rich_text_input == ''):
91
  raise gr.Error("Please enter some text.")
92
  # parse json to span attributes
@@ -132,25 +132,25 @@ def main():
132
  512//8, 512//8, region_target_token_ids[:-1], seed,
133
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
  return_vis=True)
135
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
136
  interpolation=transforms.InterpolationMode.BICUBIC,
137
  antialias=True)
138
  for color_obj_mask in color_obj_masks]
139
  text_format_dict['color_obj_atten'] = color_obj_masks
 
140
  model.remove_tokenmap_hooks()
141
 
142
  # generate image from rich text
143
  begin_time = time.time()
144
  seed_everything(seed)
145
- if background_aug:
146
- bg_aug_end = 500
147
- else:
148
- bg_aug_end = 1000
149
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
  height=height, width=width, num_inference_steps=steps,
151
  guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
  text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
- bg_aug_end=bg_aug_end)
154
  print('time lapses to generate image from rich text: %.4f' %
155
  (time.time()-begin_time))
156
  return [plain_img[0], rich_img[0], segments_vis, token_maps]
@@ -191,6 +191,12 @@ def main():
191
  maximum=1,
192
  step=0.01,
193
  value=0.)
194
  color_guidance_weight = gr.Slider(label='Color weight',
195
  info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
196
  minimum=0,
@@ -209,10 +215,6 @@ def main():
209
  value=6,
210
  elem_id="seed"
211
  )
212
- background_aug = gr.Checkbox(
213
- label='Precise region alignment',
214
- info='(For strict region alignment, select this option, but beware of potential artifacts when using with style.)',
215
- value=True)
216
  with gr.Accordion('Other Parameters', open=False):
217
  steps = gr.Slider(label='Number of Steps',
218
  minimum=0,
@@ -266,32 +268,32 @@ def main():
266
  5,
267
  0.3,
268
  0,
 
269
  6,
270
- 1,
271
  None,
272
- True
273
  ],
274
  [
275
- '{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a stove with gas burners and a built-in oven "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
276
  '',
277
- 6,
278
  0.5,
279
  0,
 
280
  6,
281
- 1,
282
  None,
283
- True
284
  ],
285
  [
286
  '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
287
  '',
288
- 4,
289
  0.3,
290
  0,
 
291
  4,
292
- 1,
293
  None,
294
- True
295
  ],
296
  ]
297
 
@@ -303,10 +305,10 @@ def main():
303
  num_segments,
304
  segment_threshold,
305
  inject_interval,
 
306
  seed,
307
  color_guidance_weight,
308
  rich_text_input,
309
- background_aug,
310
  ],
311
  outputs=[
312
  plaintext_result,
@@ -315,42 +317,42 @@ def main():
315
  token_map,
316
  ],
317
  fn=generate,
318
- # cache_examples=True,
319
  examples_per_page=20)
320
  with gr.Row():
321
  color_examples = [
322
  [
323
- '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#00ffff"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
324
  'lowres, had anatomy, bad hands, cropped, worst quality',
325
- 9,
326
- 0.25,
 
327
  0.3,
328
  6,
329
  0.5,
330
  None,
331
- True
332
  ],
333
  [
334
- '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#eeeeee"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
335
  'lowres, had anatomy, bad hands, cropped, worst quality',
336
- 9,
337
- 0.25,
 
338
  0.3,
339
  6,
340
- 0.1,
341
  None,
342
- True
343
  ],
344
  [
345
  '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
346
  '',
347
- 5,
348
- 0.3,
349
  0.5,
 
350
  6,
351
  0.5,
352
  None,
353
- False
354
  ],
355
  [
356
  '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
@@ -358,21 +360,21 @@ def main():
358
  3,
359
  0.3,
360
  0,
 
361
  9,
362
  1,
363
  None,
364
- False
365
  ],
366
  [
367
  '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
368
  '',
369
  5,
 
 
370
  0.3,
371
- 0,
372
  5,
373
  0.6,
374
  None,
375
- False
376
  ],
377
  ]
378
  gr.Examples(examples=color_examples,
@@ -383,10 +385,10 @@ def main():
383
  num_segments,
384
  segment_threshold,
385
  inject_interval,
 
386
  seed,
387
  color_guidance_weight,
388
  rich_text_input,
389
- background_aug,
390
  ],
391
  outputs=[
392
  plaintext_result,
@@ -395,7 +397,7 @@ def main():
395
  token_map,
396
  ],
397
  fn=generate,
398
- # cache_examples=True,
399
  examples_per_page=20)
400
 
401
  with gr.Row():
@@ -403,13 +405,13 @@ def main():
403
  [
404
  '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
405
  '',
406
- 5,
407
- 0.3,
 
408
  0.2,
409
  3,
410
- 0.5,
411
  None,
412
- False
413
  ],
414
  [
415
  '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
@@ -417,21 +419,21 @@ def main():
417
  5,
418
  0.3,
419
  0,
 
420
  9,
421
  0.5,
422
  None,
423
- False
424
  ],
425
  [
426
  '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
427
  '',
428
  2,
429
- 0.4,
 
430
  0,
431
  6,
432
  0.5,
433
  None,
434
- False
435
  ],
436
  ]
437
  gr.Examples(examples=style_examples,
@@ -442,10 +444,10 @@ def main():
442
  num_segments,
443
  segment_threshold,
444
  inject_interval,
 
445
  seed,
446
  color_guidance_weight,
447
  rich_text_input,
448
- background_aug,
449
  ],
450
  outputs=[
451
  plaintext_result,
@@ -454,7 +456,7 @@ def main():
454
  token_map,
455
  ],
456
  fn=generate,
457
- # cache_examples=True,
458
  examples_per_page=20)
459
 
460
  with gr.Row():
@@ -465,10 +467,10 @@ def main():
465
  5,
466
  0.3,
467
  0,
 
468
  13,
469
  1,
470
  None,
471
- False
472
  ],
473
  [
474
  '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
@@ -476,10 +478,10 @@ def main():
476
  5,
477
  0.3,
478
  0,
 
479
  13,
480
  1,
481
  None,
482
- False
483
  ],
484
  [
485
  '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
@@ -487,10 +489,10 @@ def main():
487
  5,
488
  0.3,
489
  0,
 
490
  13,
491
  1,
492
  None,
493
- False
494
  ],
495
  ]
496
  gr.Examples(examples=size_examples,
@@ -501,10 +503,10 @@ def main():
501
  num_segments,
502
  segment_threshold,
503
  inject_interval,
 
504
  seed,
505
  color_guidance_weight,
506
  rich_text_input,
507
- background_aug,
508
  ],
509
  outputs=[
510
  plaintext_result,
@@ -513,24 +515,24 @@ def main():
513
  token_map,
514
  ],
515
  fn=generate,
516
- # cache_examples=True,
517
  examples_per_page=20)
518
  generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
519
  fn=generate,
520
  inputs=[
521
  text_input,
522
  negative_prompt,
523
- height,
524
- width,
525
- seed,
526
- steps,
527
  num_segments,
528
  segment_threshold,
529
  inject_interval,
530
- guidance_weight,
 
531
  color_guidance_weight,
532
  rich_text_input,
533
- background_aug
534
  ],
535
  outputs=[plaintext_result, richtext_result, segments, token_map],
536
  _js=get_js_data
 
29
 
30
  canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
  get_js_data = """
32
+ async (text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, rich_text_input, height, width, steps, guidance_weights) => {
33
  const richEl = document.getElementById("rich-text-root");
34
  const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, JSON.stringify(data), height, width, steps, guidance_weights];
36
  }
37
  """
38
  set_js_data = """
 
66
  def generate(
67
  text_input: str,
68
  negative_text: str,
69
  num_segments: int,
70
  segment_threshold: float,
71
  inject_interval: float,
72
+ inject_background: float,
73
+ seed: int,
74
  color_guidance_weight: float,
75
  rich_text_input: str,
76
+ height: int,
77
+ width: int,
78
+ steps: int,
79
+ guidance_weight: float,
80
  ):
81
  run_dir = 'results/'
82
  os.makedirs(run_dir, exist_ok=True)
83
  # Load region diffusion model.
84
+ height = int(height) if height else 512
85
+ width = int(width) if width else 512
86
  steps = 41 if not steps else steps
87
  guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
+ text_input = rich_text_input if rich_text_input != '' and rich_text_input != None else text_input
89
+ print('text_input', text_input, width, height, steps, guidance_weight, num_segments, segment_threshold, inject_interval, inject_background, color_guidance_weight, negative_text)
90
  if (text_input == '' or rich_text_input == ''):
91
  raise gr.Error("Please enter some text.")
92
  # parse json to span attributes
 
132
  512//8, 512//8, region_target_token_ids[:-1], seed,
133
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
  return_vis=True)
135
+ color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
136
+ for obj_mask in color_obj_masks[:-1]:
137
+ color_obj_atten_all += obj_mask
138
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
139
  interpolation=transforms.InterpolationMode.BICUBIC,
140
  antialias=True)
141
  for color_obj_mask in color_obj_masks]
142
  text_format_dict['color_obj_atten'] = color_obj_masks
143
+ text_format_dict['color_obj_atten_all'] = color_obj_atten_all
144
  model.remove_tokenmap_hooks()
145
 
146
  # generate image from rich text
147
  begin_time = time.time()
148
  seed_everything(seed)
149
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
  height=height, width=width, num_inference_steps=steps,
151
  guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
  text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ inject_background=inject_background)
154
  print('time lapses to generate image from rich text: %.4f' %
155
  (time.time()-begin_time))
156
  return [plain_img[0], rich_img[0], segments_vis, token_maps]
 
191
  maximum=1,
192
  step=0.01,
193
  value=0.)
194
+ inject_background = gr.Slider(label='Unformatted token preservation',
195
+ info='(To affect less the tokens without any rich-text attributes, increase this.)',
196
+ minimum=0,
197
+ maximum=1,
198
+ step=0.01,
199
+ value=0.3)
200
  color_guidance_weight = gr.Slider(label='Color weight',
201
  info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
202
  minimum=0,
 
215
  value=6,
216
  elem_id="seed"
217
  )
218
  with gr.Accordion('Other Parameters', open=False):
219
  steps = gr.Slider(label='Number of Steps',
220
  minimum=0,
 
268
  5,
269
  0.3,
270
  0,
271
+ 0.5,
272
  6,
273
+ 0,
274
  None,
 
275
  ],
276
  [
277
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Thor Kitchen 30 Inch Wide Freestanding Gas Range with Automatic Re-Ignition System"},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
278
  '',
279
+ 7,
280
  0.5,
281
  0,
282
+ 0.5,
283
  6,
284
+ 0,
285
  None,
 
286
  ],
287
  [
288
  '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
289
  '',
290
+ 5,
291
  0.3,
292
  0,
293
+ 0.1,
294
  4,
295
+ 0,
296
  None,
 
297
  ],
298
  ]
299
 
 
305
  num_segments,
306
  segment_threshold,
307
  inject_interval,
308
+ inject_background,
309
  seed,
310
  color_guidance_weight,
311
  rich_text_input,
 
312
  ],
313
  outputs=[
314
  plaintext_result,
 
317
  token_map,
318
  ],
319
  fn=generate,
320
+ cache_examples=True,
321
  examples_per_page=20)
322
  with gr.Row():
323
  color_examples = [
324
  [
325
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
  'lowres, had anatomy, bad hands, cropped, worst quality',
327
+ 11,
328
+ 0.3,
329
+ 0.3,
330
  0.3,
331
  6,
332
  0.5,
333
  None,
 
334
  ],
335
  [
336
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
  'lowres, had anatomy, bad hands, cropped, worst quality',
338
+ 11,
339
+ 0.3,
340
+ 0.3,
341
  0.3,
342
  6,
343
+ 0.5,
344
  None,
 
345
  ],
346
  [
347
  '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
348
  '',
349
+ 10,
350
+ 0.4,
351
  0.5,
352
+ 0.3,
353
  6,
354
  0.5,
355
  None,
 
356
  ],
357
  [
358
  '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
 
360
  3,
361
  0.3,
362
  0,
363
+ 0,
364
  9,
365
  1,
366
  None,
 
367
  ],
368
  [
369
  '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
370
  '',
371
  5,
372
+ 0.4,
373
+ 0.3,
374
  0.3,
 
375
  5,
376
  0.6,
377
  None,
 
378
  ],
379
  ]
380
  gr.Examples(examples=color_examples,
 
385
  num_segments,
386
  segment_threshold,
387
  inject_interval,
388
+ inject_background,
389
  seed,
390
  color_guidance_weight,
391
  rich_text_input,
 
392
  ],
393
  outputs=[
394
  plaintext_result,
 
397
  token_map,
398
  ],
399
  fn=generate,
400
+ cache_examples=True,
401
  examples_per_page=20)
402
 
403
  with gr.Row():
 
405
  [
406
  '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
407
  '',
408
+ 10,
409
+ 0.4,
410
+ 0,
411
  0.2,
412
  3,
413
+ 0,
414
  None,
 
415
  ],
416
  [
417
  '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
 
419
  5,
420
  0.3,
421
  0,
422
+ 0,
423
  9,
424
  0.5,
425
  None,
 
426
  ],
427
  [
428
  '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
429
  '',
430
  2,
431
+ 0.35,
432
+ 0,
433
  0,
434
  6,
435
  0.5,
436
  None,
 
437
  ],
438
  ]
439
  gr.Examples(examples=style_examples,
 
444
  num_segments,
445
  segment_threshold,
446
  inject_interval,
447
+ inject_background,
448
  seed,
449
  color_guidance_weight,
450
  rich_text_input,
 
451
  ],
452
  outputs=[
453
  plaintext_result,
 
456
  token_map,
457
  ],
458
  fn=generate,
459
+ cache_examples=True,
460
  examples_per_page=20)
461
 
462
  with gr.Row():
 
467
  5,
468
  0.3,
469
  0,
470
+ 0,
471
  13,
472
  1,
473
  None,
 
474
  ],
475
  [
476
  '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
 
478
  5,
479
  0.3,
480
  0,
481
+ 0,
482
  13,
483
  1,
484
  None,
 
485
  ],
486
  [
487
  '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
 
489
  5,
490
  0.3,
491
  0,
492
+ 0,
493
  13,
494
  1,
495
  None,
 
496
  ],
497
  ]
498
  gr.Examples(examples=size_examples,
 
503
  num_segments,
504
  segment_threshold,
505
  inject_interval,
506
+ inject_background,
507
  seed,
508
  color_guidance_weight,
509
  rich_text_input,
 
510
  ],
511
  outputs=[
512
  plaintext_result,
 
515
  token_map,
516
  ],
517
  fn=generate,
518
+ cache_examples=True,
519
  examples_per_page=20)
520
  generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
521
  fn=generate,
522
  inputs=[
523
  text_input,
524
  negative_prompt,
525
  num_segments,
526
  segment_threshold,
527
  inject_interval,
528
+ inject_background,
529
+ seed,
530
  color_guidance_weight,
531
  rich_text_input,
532
+ height,
533
+ width,
534
+ steps,
535
+ guidance_weight,
536
  ],
537
  outputs=[plaintext_result, richtext_result, segments, token_map],
538
  _js=get_js_data
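Note: the inputs list wired into generate_button.click (and the matching get_js_data bridge above) has to follow the same order as the parameters of generate(). A minimal stand-in sketch of that ordering and of the fallback defaults — generate_stub and its example values are illustrative only, not part of app.py:

    # Hypothetical stub mirroring the new generate() parameter order in app.py.
    def generate_stub(text_input, negative_text, num_segments, segment_threshold,
                      inject_interval, inject_background, seed, color_guidance_weight,
                      rich_text_input, height, width, steps, guidance_weight):
        # app.py falls back to 512x512, 41 steps, and guidance 8.5 when these are unset.
        height = int(height) if height else 512
        width = int(width) if width else 512
        steps = 41 if not steps else steps
        guidance_weight = 8.5 if not guidance_weight else guidance_weight
        return dict(height=height, width=width, steps=steps, guidance_weight=guidance_weight)

    # Example values loosely based on the first example row above (illustrative only).
    print(generate_stub('a prompt', '', 5, 0.3, 0, 0.5, 6, 0, None, None, None, None, None))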
models/region_diffusion.py CHANGED
@@ -84,13 +84,13 @@ class RegionDiffusion(nn.Module):
84
  return text_embeddings
85
 
86
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
- latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, bg_aug_end=1000):
88
 
89
  if latents is None:
90
  latents = torch.randn(
91
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
 
93
- if inject_selfattn > 0:
94
  latents_reference = latents.clone().detach()
95
  self.scheduler.set_timesteps(num_inference_steps)
96
  n_styles = text_embeddings.shape[0]-1
@@ -102,11 +102,12 @@ class RegionDiffusion(nn.Module):
102
  with torch.no_grad():
103
  # tokens without any attributes
104
  feat_inject_step = t > (1-inject_selfattn) * 1000
 
105
  noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
106
- text_format_dict={})['sample']
107
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
108
  text_format_dict=text_format_dict)['sample']
109
- if inject_selfattn > 0:
110
  noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
111
  text_format_dict={})['sample']
112
  self.register_selfattn_hooks(feat_inject_step)
@@ -117,33 +118,18 @@ class RegionDiffusion(nn.Module):
117
  noise_pred_text = noise_pred_text_cur * self.masks[-1]
118
  # tokens with attributes
119
  for style_i, mask in enumerate(self.masks[:-1]):
120
- if t > bg_aug_end:
121
- rand_rgb = torch.rand([1, 3, 1, 1]).cuda()
122
- black_background = torch.ones(
123
- [1, 3, height, width]).cuda()*rand_rgb
124
- black_latent = self.encode_imgs(
125
- black_background)
126
- noise = torch.randn_like(black_latent)
127
- black_latent_noisy = self.scheduler.add_noise(
128
- black_latent, noise, t)
129
- masked_latent = (
130
- mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy
131
- noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[:1],
132
- text_format_dict={})['sample']
133
- else:
134
- masked_latent = latents
135
  self.register_replacement_hooks(feat_inject_step)
136
- noise_pred_text_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
137
  text_format_dict={})['sample']
138
  self.remove_replacement_hooks()
139
  noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
140
  noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
141
-
142
  # perform classifier-free guidance
143
  noise_pred = noise_pred_uncond + guidance_scale * \
144
  (noise_pred_text - noise_pred_uncond)
145
 
146
- if inject_selfattn > 0:
147
  noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
148
  (noise_pred_text_refer - noise_pred_uncond_refer)
149
 
@@ -174,12 +160,15 @@ class RegionDiffusion(nn.Module):
174
  imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
175
  loss = self.color_loss(
176
  avg_rgb, rgb_val[:, :, 0, 0])*100
177
- # print(loss)
178
  loss_total += loss
179
  loss_total.backward()
180
  latents = (
181
- latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()
182
 
 
 
 
 
183
  return latents
184
 
185
  def predict_x0(self, x_t, eps_t, t):
@@ -255,7 +244,7 @@ class RegionDiffusion(nn.Module):
255
  return latents
256
 
257
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
258
- guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, bg_aug_end=1000):
259
 
260
  if isinstance(prompts, str):
261
  prompts = [prompts]
@@ -271,7 +260,7 @@ class RegionDiffusion(nn.Module):
271
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
272
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
273
  use_guidance=use_guidance, text_format_dict=text_format_dict,
274
- inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end) # [1, 4, 64, 64]
275
  # Img latents -> imgs
276
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
277
 
@@ -345,8 +334,6 @@ class RegionDiffusion(nn.Module):
345
  """
346
  # out[0] - final output of residual layer
347
  # out[1] - residual hidden feature
348
- # import ipdb
349
- # ipdb.set_trace()
350
  assert out[1].shape[-1] == 16
351
  activations[name] = out[1].detach()
352
  attention_dict = collections.defaultdict(list)
 
84
  return text_embeddings
85
 
86
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
+ latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, inject_background=0):
88
 
89
  if latents is None:
90
  latents = torch.randn(
91
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
 
93
+ if inject_selfattn > 0 or inject_background > 0:
94
  latents_reference = latents.clone().detach()
95
  self.scheduler.set_timesteps(num_inference_steps)
96
  n_styles = text_embeddings.shape[0]-1
 
102
  with torch.no_grad():
103
  # tokens without any attributes
104
  feat_inject_step = t > (1-inject_selfattn) * 1000
105
+ background_inject_step = i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0
106
  noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
107
+ text_format_dict={})['sample']
108
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
109
  text_format_dict=text_format_dict)['sample']
110
+ if inject_selfattn > 0 or inject_background > 0:
111
  noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
112
  text_format_dict={})['sample']
113
  self.register_selfattn_hooks(feat_inject_step)
 
118
  noise_pred_text = noise_pred_text_cur * self.masks[-1]
119
  # tokens with attributes
120
  for style_i, mask in enumerate(self.masks[:-1]):
121
  self.register_replacement_hooks(feat_inject_step)
122
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
123
  text_format_dict={})['sample']
124
  self.remove_replacement_hooks()
125
  noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
126
  noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
127
+
128
  # perform classifier-free guidance
129
  noise_pred = noise_pred_uncond + guidance_scale * \
130
  (noise_pred_text - noise_pred_uncond)
131
 
132
+ if inject_selfattn > 0 or inject_background > 0:
133
  noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
134
  (noise_pred_text_refer - noise_pred_uncond_refer)
135
 
 
160
  imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
161
  loss = self.color_loss(
162
  avg_rgb, rgb_val[:, :, 0, 0])*100
 
163
  loss_total += loss
164
  loss_total.backward()
165
  latents = (
166
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone()
167
 
168
+ # apply background injection
169
+ if background_inject_step:
170
+ latents = latents_reference * self.masks[-1] + latents * \
171
+ (1-self.masks[-1])
172
  return latents
173
 
174
  def predict_x0(self, x_t, eps_t, t):
 
244
  return latents
245
 
246
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
247
+ guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, inject_background=0):
248
 
249
  if isinstance(prompts, str):
250
  prompts = [prompts]
 
260
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
261
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
262
  use_guidance=use_guidance, text_format_dict=text_format_dict,
263
+ inject_selfattn=inject_selfattn, inject_background=inject_background) # [1, 4, 64, 64]
264
  # Img latents -> imgs
265
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
266
 
 
334
  """
335
  # out[0] - final output of residual layer
336
  # out[1] - residual hidden feature
 
 
337
  assert out[1].shape[-1] == 16
338
  activations[name] = out[1].detach()
339
  attention_dict = collections.defaultdict(list)
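Note: the new inject_background path in produce_latents() keeps a reference trajectory (latents_reference, driven by the prompt without attributes) and, at the single step picked by inject_background, copies it back into the region covered by self.masks[-1], i.e. the tokens without any rich-text attributes. A minimal sketch of that blend on toy tensors, assuming a random stand-in mask and skipping the actual denoising updates:

    import torch

    num_steps = 50
    inject_background = 0.3
    latents = torch.randn(1, 4, 64, 64)            # rich-text trajectory (toy)
    latents_reference = torch.randn(1, 4, 64, 64)  # reference trajectory (toy)
    background_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()  # stand-in for self.masks[-1]

    for i in range(num_steps):
        # ... one denoising step for latents and latents_reference would happen here ...
        background_inject_step = i == int(inject_background * num_steps) and inject_background > 0
        if background_inject_step:
            # background region follows the reference; formatted regions keep their own latents
            latents = latents_reference * background_mask + latents * (1 - background_mask)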
utils/attention_utils.py CHANGED
@@ -6,25 +6,26 @@ import seaborn as sns
6
  import torch
7
  import torchvision
8
 
9
- from sklearn.cluster import KMeans
 
10
 
11
  SelfAttentionLayers = [
12
- # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
13
- # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
14
  'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
15
- # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
16
  'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
17
  'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
18
  'mid_block.attentions.0.transformer_blocks.0.attn1',
19
  'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
20
  'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
21
  'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
22
- # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
23
  'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
24
- # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
25
- # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
26
- # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
27
- # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
28
  ]
29
 
30
 
@@ -208,8 +209,8 @@ def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_token
208
  return attention_maps_averaged_normalized, token_maps_vis
209
 
210
 
211
- def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
212
- preprocess=False, segment_threshold=0.30, num_segments=9, return_vis=False):
213
  r"""Function to visualize attention maps.
214
  Args:
215
  save_dir (str): Path to save attention maps
@@ -219,9 +220,11 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
219
 
220
  # create the segmentation mask using self-attention maps
221
  resolution = 32
222
- attn_maps_1024 = {8: [], 16: [], 32: []}
223
  for attn_map in selfattn_maps.values():
224
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
 
 
225
  attn_map = attn_map.reshape(
226
  1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
227
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
@@ -229,10 +232,15 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
229
  attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
230
  1, resolution**2, resolution_map**2))
231
  attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
232
- for v in attn_maps_1024.values()], -1).numpy()
233
- kmeans = KMeans(n_clusters=num_segments,
234
- n_init=10).fit(attn_maps_1024)
235
- clusters = kmeans.labels_
236
  clusters = clusters.reshape(resolution, resolution)
237
  fig = plt.figure()
238
  plt.imshow(clusters)
@@ -258,6 +266,10 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
258
 
259
  cross_attn_maps_1024 = torch.cat(
260
  cross_attn_maps_1024).mean(0).cpu().numpy()
261
  normalized_span_maps = []
262
  for token_ids in obj_tokens:
263
  span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
@@ -297,7 +309,7 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
297
  foreground_token_maps = [token_map[None, :, :]
298
  for token_map in foreground_token_maps]
299
  token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
300
- save_dir, kmeans_seed, tokens_vis)
301
  resized_token_maps = [token_map.unsqueeze(1).repeat(
302
  [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
303
  if return_vis:
 
6
  import torch
7
  import torchvision
8
 
9
+ from utils.richtext_utils import seed_everything
10
+ from sklearn.cluster import SpectralClustering
11
 
12
  SelfAttentionLayers = [
13
+ 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
14
+ 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
15
  'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
16
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
17
  'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
18
  'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
19
  'mid_block.attentions.0.transformer_blocks.0.attn1',
20
  'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
21
  'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
22
  'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
23
+ 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
24
  'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
25
+ 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
26
+ 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
27
+ 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
28
+ 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
29
  ]
30
 
31
 
 
209
  return attention_maps_averaged_normalized, token_maps_vis
210
 
211
 
212
+ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
213
+ preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
214
  r"""Function to visualize attention maps.
215
  Args:
216
  save_dir (str): Path to save attention maps
 
220
 
221
  # create the segmentation mask using self-attention maps
222
  resolution = 32
223
+ attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
224
  for attn_map in selfattn_maps.values():
225
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
226
+ if resolution_map != resolution:
227
+ continue
228
  attn_map = attn_map.reshape(
229
  1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
230
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
 
232
  attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
233
  1, resolution**2, resolution_map**2))
234
  attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
235
+ for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
236
+ if save_attn:
237
+ print('saving self-attention maps...', attn_maps_1024.shape)
238
+ torch.save(torch.from_numpy(attn_maps_1024),
239
+ 'results/maps/selfattn_maps.pth')
240
+ seed_everything(seed)
241
+ sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
242
+ assign_labels='kmeans')
243
+ clusters = sc.fit_predict(attn_maps_1024)
244
  clusters = clusters.reshape(resolution, resolution)
245
  fig = plt.figure()
246
  plt.imshow(clusters)
 
266
 
267
  cross_attn_maps_1024 = torch.cat(
268
  cross_attn_maps_1024).mean(0).cpu().numpy()
269
+ if save_attn:
270
+ print('saving cross-attention maps...', cross_attn_maps_1024.shape)
271
+ torch.save(torch.from_numpy(cross_attn_maps_1024),
272
+ 'results/maps/crossattn_maps.pth')
273
  normalized_span_maps = []
274
  for token_ids in obj_tokens:
275
  span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
 
309
  foreground_token_maps = [token_map[None, :, :]
310
  for token_map in foreground_token_maps]
311
  token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
312
+ save_dir, seed, tokens_vis)
313
  resized_token_maps = [token_map.unsqueeze(1).repeat(
314
  [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
315
  if return_vis:
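Note: with the resolution filter added above, only the 32x32 self-attention maps are kept, so attn_maps_1024 becomes a square (1024, 1024) location-to-location affinity that get_token_maps() now clusters with SpectralClustering instead of KMeans. A minimal, self-contained sketch of that clustering call, using a random symmetric matrix as a stand-in for the averaged self-attention maps:

    import numpy as np
    from sklearn.cluster import SpectralClustering

    resolution = 32
    num_segments = 5
    rng = np.random.default_rng(0)

    # Toy stand-in for the averaged 32x32 self-attention maps (shape: 1024 x 1024).
    affinity = rng.random((resolution**2, resolution**2))
    affinity = (affinity + affinity.T) / 2  # a precomputed affinity should be symmetric

    sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
                            assign_labels='kmeans')
    clusters = sc.fit_predict(affinity)              # one segment label per spatial location
    clusters = clusters.reshape(resolution, resolution)
    print(clusters.shape)  # (32, 32)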