songweig commited on
Commit
d908d7c
β€’
1 Parent(s): bdf1746

reduce memory footprint

Browse files
Files changed (2) hide show
  1. app.py +161 -161
  2. models/region_diffusion_xl.py +11 -6
app.py CHANGED
@@ -260,45 +260,114 @@ def main():
260
  with gr.Row():
261
  gr.Markdown(help_text)
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  with gr.Row():
264
- footnote_examples = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  [
266
- '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
267
- '',
268
- 9,
 
269
  0.3,
270
  0.3,
 
271
  0.5,
272
- 3,
273
- 0,
274
  None,
275
  ],
276
  [
277
- '{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard."}]}',
278
  '',
279
- 12,
280
- 0.4,
281
- 0.3,
282
  0.5,
283
- 3,
284
- 0,
285
- None,
286
- ],
287
- [
288
- '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
289
- '',
290
- 5,
291
  0.3,
292
- 0,
293
- 0.1,
294
- 4,
295
- 0,
296
  None,
297
  ],
298
  ]
299
-
300
- gr.Examples(examples=footnote_examples,
301
- label='Footnote examples',
302
  inputs=[
303
  text_input,
304
  negative_prompt,
@@ -319,55 +388,93 @@ def main():
319
  fn=generate,
320
  cache_examples=True,
321
  examples_per_page=20)
 
322
  # with gr.Row():
323
- # color_examples = [
324
  # [
325
- # '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
- # 'lowres, had anatomy, bad hands, cropped, worst quality',
327
- # 11,
328
- # 0.5,
329
- # 0.3,
330
- # 0.3,
331
- # 6,
332
- # 0.5,
333
  # None,
334
  # ],
335
  # [
336
- # '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#ff5df1"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
- # 'lowres, had anatomy, bad hands, cropped, worst quality',
338
- # 11,
339
- # 0.5,
340
- # 0.3,
341
- # 0.3,
342
  # 6,
343
  # 0.5,
344
  # None,
345
  # ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  # [
347
- # '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
348
- # 'lowres, had anatomy, bad hands, cropped, worst quality',
349
- # 11,
350
- # 0.5,
351
  # 0.3,
 
 
 
 
 
 
 
 
 
 
352
  # 0.3,
353
- # 6,
354
- # 0.5,
 
 
355
  # None,
356
  # ],
357
  # [
358
- # '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
359
  # '',
360
- # 10,
361
- # 0.5,
362
- # 0.5,
363
  # 0.3,
364
- # 7,
365
- # 0.5,
 
 
366
  # None,
367
  # ],
368
  # ]
369
- # gr.Examples(examples=color_examples,
370
- # label='Font color examples',
371
  # inputs=[
372
  # text_input,
373
  # negative_prompt,
@@ -388,113 +495,6 @@ def main():
388
  # fn=generate,
389
  # cache_examples=True,
390
  # examples_per_page=20)
391
-
392
- with gr.Row():
393
- style_examples = [
394
- [
395
- '{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
396
- '',
397
- 10,
398
- 0.6,
399
- 0,
400
- 0.4,
401
- 5,
402
- 0,
403
- None,
404
- ],
405
- [
406
- '{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
407
- '',
408
- 2,
409
- 0.6,
410
- 0,
411
- 0,
412
- 6,
413
- 0.5,
414
- None,
415
- ],
416
- ]
417
- gr.Examples(examples=style_examples,
418
- label='Font style examples',
419
- inputs=[
420
- text_input,
421
- negative_prompt,
422
- num_segments,
423
- segment_threshold,
424
- inject_interval,
425
- inject_background,
426
- seed,
427
- color_guidance_weight,
428
- rich_text_input,
429
- ],
430
- outputs=[
431
- plaintext_result,
432
- richtext_result,
433
- segments,
434
- token_map,
435
- ],
436
- fn=generate,
437
- cache_examples=True,
438
- examples_per_page=20)
439
-
440
- with gr.Row():
441
- size_examples = [
442
- [
443
- '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
444
- '',
445
- 5,
446
- 0.3,
447
- 0,
448
- 0,
449
- 3,
450
- 1,
451
- None,
452
- ],
453
- [
454
- '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
455
- '',
456
- 5,
457
- 0.3,
458
- 0,
459
- 0,
460
- 3,
461
- 1,
462
- None,
463
- ],
464
- [
465
- '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
466
- '',
467
- 5,
468
- 0.3,
469
- 0,
470
- 0,
471
- 3,
472
- 1,
473
- None,
474
- ],
475
- ]
476
- gr.Examples(examples=size_examples,
477
- label='Font size examples',
478
- inputs=[
479
- text_input,
480
- negative_prompt,
481
- num_segments,
482
- segment_threshold,
483
- inject_interval,
484
- inject_background,
485
- seed,
486
- color_guidance_weight,
487
- rich_text_input,
488
- ],
489
- outputs=[
490
- plaintext_result,
491
- richtext_result,
492
- segments,
493
- token_map,
494
- ],
495
- fn=generate,
496
- cache_examples=True,
497
- examples_per_page=20)
498
  generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
499
  fn=generate,
500
  inputs=[
 
260
  with gr.Row():
261
  gr.Markdown(help_text)
262
 
263
+ # with gr.Row():
264
+ # footnote_examples = [
265
+ # [
266
+ # '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
267
+ # '',
268
+ # 9,
269
+ # 0.3,
270
+ # 0.3,
271
+ # 0.5,
272
+ # 3,
273
+ # 0,
274
+ # None,
275
+ # ],
276
+ # [
277
+ # '{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard."}]}',
278
+ # '',
279
+ # 12,
280
+ # 0.4,
281
+ # 0.3,
282
+ # 0.5,
283
+ # 3,
284
+ # 0,
285
+ # None,
286
+ # ],
287
+ # [
288
+ # '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
289
+ # '',
290
+ # 5,
291
+ # 0.3,
292
+ # 0,
293
+ # 0.1,
294
+ # 4,
295
+ # 0,
296
+ # None,
297
+ # ],
298
+ # ]
299
+
300
+ # gr.Examples(examples=footnote_examples,
301
+ # label='Footnote examples',
302
+ # inputs=[
303
+ # text_input,
304
+ # negative_prompt,
305
+ # num_segments,
306
+ # segment_threshold,
307
+ # inject_interval,
308
+ # inject_background,
309
+ # seed,
310
+ # color_guidance_weight,
311
+ # rich_text_input,
312
+ # ],
313
+ # outputs=[
314
+ # plaintext_result,
315
+ # richtext_result,
316
+ # segments,
317
+ # token_map,
318
+ # ],
319
+ # fn=generate,
320
+ # cache_examples=True,
321
+ # examples_per_page=20)
322
  with gr.Row():
323
+ color_examples = [
324
+ # [
325
+ # '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
+ # 'lowres, had anatomy, bad hands, cropped, worst quality',
327
+ # 11,
328
+ # 0.5,
329
+ # 0.3,
330
+ # 0.3,
331
+ # 6,
332
+ # 0.5,
333
+ # None,
334
+ # ],
335
+ # [
336
+ # '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#ff5df1"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
+ # 'lowres, had anatomy, bad hands, cropped, worst quality',
338
+ # 11,
339
+ # 0.5,
340
+ # 0.3,
341
+ # 0.3,
342
+ # 6,
343
+ # 0.5,
344
+ # None,
345
+ # ],
346
  [
347
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
348
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
349
+ 11,
350
+ 0.5,
351
  0.3,
352
  0.3,
353
+ 6,
354
  0.5,
 
 
355
  None,
356
  ],
357
  [
358
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
359
  '',
360
+ 10,
361
+ 0.5,
 
362
  0.5,
 
 
 
 
 
 
 
 
363
  0.3,
364
+ 7,
365
+ 0.5,
 
 
366
  None,
367
  ],
368
  ]
369
+ gr.Examples(examples=color_examples,
370
+ label='Font color examples',
 
371
  inputs=[
372
  text_input,
373
  negative_prompt,
 
388
  fn=generate,
389
  cache_examples=True,
390
  examples_per_page=20)
391
+
392
  # with gr.Row():
393
+ # style_examples = [
394
  # [
395
+ # '{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
396
+ # '',
397
+ # 10,
398
+ # 0.6,
399
+ # 0,
400
+ # 0.4,
401
+ # 5,
402
+ # 0,
403
  # None,
404
  # ],
405
  # [
406
+ # '{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
407
+ # '',
408
+ # 2,
409
+ # 0.6,
410
+ # 0,
411
+ # 0,
412
  # 6,
413
  # 0.5,
414
  # None,
415
  # ],
416
+ # ]
417
+ # gr.Examples(examples=style_examples,
418
+ # label='Font style examples',
419
+ # inputs=[
420
+ # text_input,
421
+ # negative_prompt,
422
+ # num_segments,
423
+ # segment_threshold,
424
+ # inject_interval,
425
+ # inject_background,
426
+ # seed,
427
+ # color_guidance_weight,
428
+ # rich_text_input,
429
+ # ],
430
+ # outputs=[
431
+ # plaintext_result,
432
+ # richtext_result,
433
+ # segments,
434
+ # token_map,
435
+ # ],
436
+ # fn=generate,
437
+ # cache_examples=True,
438
+ # examples_per_page=20)
439
+
440
+ # with gr.Row():
441
+ # size_examples = [
442
  # [
443
+ # '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
444
+ # '',
445
+ # 5,
 
446
  # 0.3,
447
+ # 0,
448
+ # 0,
449
+ # 3,
450
+ # 1,
451
+ # None,
452
+ # ],
453
+ # [
454
+ # '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
455
+ # '',
456
+ # 5,
457
  # 0.3,
458
+ # 0,
459
+ # 0,
460
+ # 3,
461
+ # 1,
462
  # None,
463
  # ],
464
  # [
465
+ # '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
466
  # '',
467
+ # 5,
 
 
468
  # 0.3,
469
+ # 0,
470
+ # 0,
471
+ # 3,
472
+ # 1,
473
  # None,
474
  # ],
475
  # ]
476
+ # gr.Examples(examples=size_examples,
477
+ # label='Font size examples',
478
  # inputs=[
479
  # text_input,
480
  # negative_prompt,
 
495
  # fn=generate,
496
  # cache_examples=True,
497
  # examples_per_page=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
499
  fn=generate,
500
  inputs=[
models/region_diffusion_xl.py CHANGED
@@ -846,12 +846,16 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
846
  # apply guidance
847
  if use_guidance and t < text_format_dict['guidance_start_step']:
848
  with torch.enable_grad():
 
 
849
  if not latents.requires_grad:
850
  latents.requires_grad = True
851
  # import ipdb;ipdb.set_trace()
852
- latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
 
853
  latents_inp = latents_0 / self.vae.config.scaling_factor
854
- imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
 
855
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
856
  loss_total = 0.
857
  for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
@@ -863,6 +867,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
863
  loss_total.backward()
864
  latents = (
865
  latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
 
866
 
867
  # apply background injection
868
  if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
@@ -1023,7 +1028,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
1023
  PyTorch Forward hook to save outputs at each forward pass.
1024
  """
1025
  if 'attn1' in name:
1026
- modified_args = (args[0], self.self_attention_maps_cur[name])
1027
  return modified_args
1028
  # cross attention injection
1029
  # elif 'attn2' in name:
@@ -1039,7 +1044,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
1039
  PyTorch Forward hook to save outputs at each forward pass.
1040
  """
1041
  modified_args = (args[0], args[1],
1042
- self.self_attention_maps_cur[name])
1043
  return modified_args
1044
  for name, module in self.unet.named_modules():
1045
  leaf_name = name.split('.')[-1]
@@ -1077,7 +1082,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
1077
  # activations[name] = out[1][1].detach()
1078
  else:
1079
  assert out[1][1].shape[-1] != 77
1080
- activations[name] = out[1][1].detach()
1081
 
1082
  def save_resnet_activations(activations, name, module, inp, out):
1083
  r"""
@@ -1087,7 +1092,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
1087
  # out[1] - residual hidden feature
1088
  # import ipdb;ipdb.set_trace()
1089
  assert out[1].shape[-1] == 64
1090
- activations[name] = out[1].detach()
1091
  attention_dict = collections.defaultdict(list)
1092
  for name, module in self.unet.named_modules():
1093
  leaf_name = name.split('.')[-1]
 
846
  # apply guidance
847
  if use_guidance and t < text_format_dict['guidance_start_step']:
848
  with torch.enable_grad():
849
+ self.unet.to(device='cpu')
850
+ torch.cuda.empty_cache()
851
  if not latents.requires_grad:
852
  latents.requires_grad = True
853
  # import ipdb;ipdb.set_trace()
854
+ # latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
855
+ latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=torch.bfloat16)
856
  latents_inp = latents_0 / self.vae.config.scaling_factor
857
+ imgs = self.vae.to(dtype=latents_inp.dtype).decode(latents_inp).sample
858
+ # imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
859
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
860
  loss_total = 0.
861
  for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
 
867
  loss_total.backward()
868
  latents = (
869
  latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
870
+ self.unet.to(device=latents.device)
871
 
872
  # apply background injection
873
  if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
 
1028
  PyTorch Forward hook to save outputs at each forward pass.
1029
  """
1030
  if 'attn1' in name:
1031
+ modified_args = (args[0], self.self_attention_maps_cur[name].to(args[0].device))
1032
  return modified_args
1033
  # cross attention injection
1034
  # elif 'attn2' in name:
 
1044
  PyTorch Forward hook to save outputs at each forward pass.
1045
  """
1046
  modified_args = (args[0], args[1],
1047
+ self.self_attention_maps_cur[name].to(args[0].device))
1048
  return modified_args
1049
  for name, module in self.unet.named_modules():
1050
  leaf_name = name.split('.')[-1]
 
1082
  # activations[name] = out[1][1].detach()
1083
  else:
1084
  assert out[1][1].shape[-1] != 77
1085
+ activations[name] = out[1][1].detach().cpu()
1086
 
1087
  def save_resnet_activations(activations, name, module, inp, out):
1088
  r"""
 
1092
  # out[1] - residual hidden feature
1093
  # import ipdb;ipdb.set_trace()
1094
  assert out[1].shape[-1] == 64
1095
+ activations[name] = out[1].detach().cpu()
1096
  attention_dict = collections.defaultdict(list)
1097
  for name, module in self.unet.named_modules():
1098
  leaf_name = name.split('.')[-1]