silentchen committed on
Commit 72c1946 · 1 Parent(s): 6ae5687

Upload app.py

Files changed (1)
  1. app.py +250 -452
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  from omegaconf import OmegaConf
4
- # from layout_guidance.inference import inference
5
  from transformers import CLIPTextModel, CLIPTokenizer
6
  from diffusers import AutoencoderKL, LMSDiscreteScheduler
7
  from my_model import unet_2d_condition
@@ -9,151 +8,17 @@ import json
9
  import numpy as np
10
  from PIL import Image, ImageDraw, ImageFont
11
  from functools import partial
12
- from collections import Counter
13
  import math
14
- import gc
15
  from utils import compute_ca_loss
16
  from gradio import processing_utils
17
  from typing import Optional
18
 
19
  import warnings
20
 
21
- from datetime import datetime
22
-
23
- from huggingface_hub import hf_hub_download
24
-
25
- hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
26
-
27
  import sys
28
 
29
  sys.tracebacklimit = 0
30
 
31
-
32
- def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
33
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
34
- return torch.load(cache_file, map_location='cpu')
35
-
36
-
37
- def load_ckpt_config_from_hf(modality):
38
- ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
39
- config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
40
- return ckpt, config
41
-
42
-
43
- def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
44
- pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
45
- config = OmegaConf.create(config["_content"]) # config used in training
46
- config.alpha_scale = 1.0
47
- config.model['params']['is_inpaint'] = is_inpaint
48
- config.model['params']['is_style'] = is_style
49
-
50
- if common_instances is None:
51
- common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
52
- common_instances = load_common_ckpt(config, common_ckpt)
53
-
54
- loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
55
-
56
- return loaded_model_list, common_instances
57
-
58
-
59
- class Instance:
60
- def __init__(self, capacity=2):
61
- self.model_type = 'base'
62
- self.loaded_model_list = {}
63
- self.counter = Counter()
64
- self.global_counter = Counter()
65
- self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
66
- 'gligen-generation-text-box',
67
- is_inpaint=False, is_style=False, common_instances=None
68
- )
69
- self.capacity = capacity
70
-
71
- def _log(self, model_type, batch_size, instruction, phrase_list):
72
- self.counter[model_type] += 1
73
- self.global_counter[model_type] += 1
74
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
75
- print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
76
- current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
77
- ))
78
-
79
- def get_model(self, model_type, batch_size, instruction, phrase_list):
80
- if model_type in self.loaded_model_list:
81
- self._log(model_type, batch_size, instruction, phrase_list)
82
- return self.loaded_model_list[model_type]
83
-
84
- if self.capacity == len(self.loaded_model_list):
85
- least_used_type = self.counter.most_common()[-1][0]
86
- del self.loaded_model_list[least_used_type]
87
- del self.counter[least_used_type]
88
- gc.collect()
89
- torch.cuda.empty_cache()
90
-
91
- self.loaded_model_list[model_type] = self._get_model(model_type)
92
- self._log(model_type, batch_size, instruction, phrase_list)
93
- return self.loaded_model_list[model_type]
94
-
95
- def _get_model(self, model_type):
96
- if model_type == 'base':
97
- return ckpt_load_helper(
98
- 'gligen-generation-text-box',
99
- is_inpaint=False, is_style=False, common_instances=self.common_instances
100
- )[0]
101
- elif model_type == 'inpaint':
102
- return ckpt_load_helper(
103
- 'gligen-inpainting-text-box',
104
- is_inpaint=True, is_style=False, common_instances=self.common_instances
105
- )[0]
106
- elif model_type == 'style':
107
- return ckpt_load_helper(
108
- 'gligen-generation-text-image-box',
109
- is_inpaint=False, is_style=True, common_instances=self.common_instances
110
- )[0]
111
-
112
- assert False
113
-
114
-
115
- # instance = Instance()
116
-
117
-
118
- def load_clip_model():
119
- from transformers import CLIPProcessor, CLIPModel
120
- version = "openai/clip-vit-large-patch14"
121
- model = CLIPModel.from_pretrained(version).cuda()
122
- processor = CLIPProcessor.from_pretrained(version)
123
-
124
- return {
125
- 'version': version,
126
- 'model': model,
127
- 'processor': processor,
128
- }
129
-
130
-
131
- # clip_model = load_clip_model()
132
-
133
-
134
- class ImageMask(gr.components.Image):
135
- """
136
- Sets: source="canvas", tool="sketch"
137
- """
138
-
139
- is_template = True
140
-
141
- def __init__(self, **kwargs):
142
- super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
143
-
144
- def preprocess(self, x):
145
- if x is None:
146
- return x
147
- if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
148
- decode_image = processing_utils.decode_base64_to_image(x)
149
- width, height = decode_image.size
150
- mask = np.zeros((height, width, 4), dtype=np.uint8)
151
- mask[..., -1] = 255
152
- mask = self.postprocess(mask)
153
- x = {'image': x, 'mask': mask}
154
- return super().preprocess(x)
155
-
156
-
157
  class Blocks(gr.Blocks):
158
 
159
  def __init__(
@@ -206,19 +71,7 @@ def draw_box(boxes=[], texts=[], img=None):
206
  fill=(255, 255, 255))
207
  return img
208
 
209
- with open('./conf/unet/config.json') as f:
210
- unet_config = json.load(f)
211
-
212
- unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
213
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
214
- text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
215
- vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
216
- attn_map = None
217
- cfg = OmegaConf.load('./conf/net_conf.yaml')
218
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
219
- unet.to(device)
220
- text_encoder.to(device)
221
- vae.to(device)
222
  def inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg,attn_map, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
223
  uncond_input = tokenizer(
224
  [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
@@ -328,7 +181,7 @@ def auto_append_grounding(language_instruction, grounding_texts):
328
  return language_instruction
329
 
330
 
331
- def generate(language_instruction, grounding_texts, sketch_pad,
332
  loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
333
  state):
334
  if 'boxes' not in state:
@@ -406,26 +259,16 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):
406
 
407
 
408
  def draw(input, grounding_texts, new_image_trigger, state):
409
-
410
  if type(input) == dict:
411
  image = input['image']
412
  mask = input['mask']
413
  else:
414
  mask = input
415
  if mask.ndim == 3:
416
- mask = mask[..., 0]
417
 
418
  image_scale = 1.0
419
 
420
- mask = binarize(mask)
421
- if mask.shape != (512, 512):
422
- # assert False, "should not receive any non- 512x512 masks."
423
- if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
424
- mask = center_crop(mask, state['inpaint_hw'])
425
- image = center_crop(state['original_image'], state['inpaint_hw'])
426
- else:
427
- mask = np.zeros((512, 512), dtype=np.uint8)
428
- # mask = center_crop(mask)
429
  mask = binarize(mask)
430
 
431
  if type(mask) != np.ndarray:
@@ -464,14 +307,8 @@ def draw(input, grounding_texts, new_image_trigger, state):
464
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
465
  if len(grounding_texts) < len(state['boxes']):
466
  grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
467
- print("state", state)
468
  box_image = draw_box(state['boxes'], grounding_texts, image)
469
 
470
- if box_image is not None and state.get('inpaint_hw', None):
471
- inpaint_hw = state['inpaint_hw']
472
- box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
473
- original_image = state['original_image'].copy()
474
- box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
475
  return [box_image, new_image_trigger, image_scale, state]
476
 
477
 
@@ -479,291 +316,252 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
479
  if task != 'Grounded Inpainting':
480
  sketch_pad_trigger = sketch_pad_trigger + 1
481
  blank_samples = batch_size % 2 if batch_size > 1 else 0
482
- out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
483
- + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
484
- + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
485
- state = {}
486
- return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
487
-
488
-
489
- css = """
490
- #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
491
- {
492
- height: var(--height) !important;
493
- max-height: var(--height) !important;
494
- min-height: var(--height) !important;
495
- }
496
- #paper-info a {
497
- color:#008AD7;
498
- text-decoration: none;
499
- }
500
- #paper-info a:hover {
501
- cursor: pointer;
502
- text-decoration: none;
503
- }
504
-
505
- .tooltip {
506
- color: #555;
507
- position: relative;
508
- display: inline-block;
509
- cursor: pointer;
510
- }
511
-
512
- .tooltip .tooltiptext {
513
- visibility: hidden;
514
- width: 400px;
515
- background-color: #555;
516
- color: #fff;
517
- text-align: center;
518
- padding: 5px;
519
- border-radius: 5px;
520
- position: absolute;
521
- z-index: 1; /* Set z-index to 1 */
522
- left: 10px;
523
- top: 100%;
524
- opacity: 0;
525
- transition: opacity 0.3s;
526
- }
527
-
528
- .tooltip:hover .tooltiptext {
529
- visibility: visible;
530
- opacity: 1;
531
- z-index: 9999; /* Set a high z-index value when hovering */
532
- }
533
-
534
-
535
- """
536
-
537
- rescale_js = """
538
- function(x) {
539
- const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
540
- let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
541
- const image_width = root.querySelector('#img2img_image').clientWidth;
542
- const target_height = parseInt(image_width * image_scale);
543
- document.body.style.setProperty('--height', `${target_height}px`);
544
- root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
545
- root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
546
- return x;
547
- }
548
- """
549
-
550
- with Blocks(
551
- css=css,
552
- analytics_enabled=False,
553
- title="Layout-Guidance demo",
554
- ) as main:
555
- description = """<p style="text-align: center; font-weight: bold;">
556
- <span style="font-size: 28px">Layout Guidance</span>
557
- <br>
558
- <span style="font-size: 18px" id="paper-info">
559
- [<a href=" " target="_blank">Project Page</a>]
560
- [<a href=" " target="_blank">Paper</a>]
561
- [<a href=" " target="_blank">GitHub</a>]
562
- </span>
563
- </p>
564
- """
565
- gr.HTML(description)
566
- with gr.Column():
567
- language_instruction = gr.Textbox(
568
- label="Text Prompt",
569
- )
570
- grounding_instruction = gr.Textbox(
571
- label="Grounding instruction (Separated by semicolon)",
572
- )
573
- sketch_pad_trigger = gr.Number(value=0, visible=False)
574
- sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
575
- init_white_trigger = gr.Number(value=0, visible=False)
576
- image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
577
- new_image_trigger = gr.Number(value=0, visible=False)
578
-
579
-
580
-
581
- with gr.Row():
582
- sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
583
- out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
584
- out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
585
- # out_gen_2 = gr.Image(type="pil", visible=True, label="Generated Image")
586
- # out_gen_3 = gr.Image(type="pil", visible=True, show_label=False)
587
- # out_gen_4 = gr.Image(type="pil", visible=True, show_label=False)
588
-
589
- with gr.Row():
590
- clear_btn = gr.Button(value='Clear')
591
- gen_btn = gr.Button(value='Generate')
592
- # clear_btn = gr.Button(value='Clear')
593
- # clear_btn = gr.Button(value='Clear')
594
-
595
- with gr.Accordion("Advanced Options", open=False):
596
- with gr.Column():
597
- description = """<div class="tooltip">Loss Scale Factor &#9432
598
- <span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the better control we get while it sometimes losses fidelity. </span>
599
- </div>
600
- <div class="tooltip">Guidance Scale &#9432
601
- <span class="tooltiptext">The scale factor of classifier-free guidance. </span>
602
- </div>
603
- <div class="tooltip" >Max Iteration per Step &#9432
604
- <span class="tooltiptext">The max iterations of backward guidance in each diffusion inference process.</span>
605
- </div>
606
- <div class="tooltip" >Loss Threshold &#9432
607
- <span class="tooltiptext">The threshold of loss. If the loss computed by cross-attention map is smaller then the threshold, the backward guidance is stopped. </span>
608
- </div>
609
- <div class="tooltip" >Max Step of Backward Guidance &#9432
610
- <span class="tooltiptext">The max steps of backward guidance in diffusion inference process.</span>
611
- </div>
612
- """
613
- gr.HTML(description)
614
- Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,label="Loss Scale Factor")
615
- guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
616
- batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
617
- max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
618
- loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
619
- max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
620
- # fix_seed = gr.Checkbox(value=True, label="Fixed seed")
621
- rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
622
-
623
- # with gr.Column(scale=4):
624
- # gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
625
- # with gr.Row():
626
- # out_gen_1 = gr.Image(type="pil", visible=True, show_label=False, label="Generated Image")
627
- # out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
628
- # with gr.Row():
629
- # out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
630
- # out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
631
-
632
- state = gr.State({})
633
-
634
-
635
- class Controller:
636
- def __init__(self):
637
- self.calls = 0
638
- self.tracks = 0
639
- self.resizes = 0
640
- self.scales = 0
641
-
642
- def init_white(self, init_white_trigger):
643
- self.calls += 1
644
- return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
645
-
646
- def change_n_samples(self, n_samples):
647
- blank_samples = n_samples % 2 if n_samples > 1 else 0
648
- return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
649
- + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
650
-
651
- def resize_centercrop(self, state):
652
- self.resizes += 1
653
- image = state['original_image'].copy()
654
- inpaint_hw = int(0.9 * min(*image.shape[:2]))
655
- state['inpaint_hw'] = inpaint_hw
656
- image_cc = center_crop(image, inpaint_hw)
657
- # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
658
- return image_cc, state
659
-
660
- def resize_masked(self, state):
661
- self.resizes += 1
662
- image = state['original_image'].copy()
663
- inpaint_hw = int(0.9 * min(*image.shape[:2]))
664
- state['inpaint_hw'] = inpaint_hw
665
- image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
666
- state['masked_image'] = image_mask.copy()
667
- # print(f'mask triggered {self.resizes}')
668
- return image_mask, state
669
-
670
- def switch_task_hide_cond(self, task):
671
- cond = False
672
- if task == "Grounded Generation":
673
- cond = True
674
-
675
- return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None,
676
- visible=False), gr.Slider.update(
677
- visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
678
-
679
-
680
- controller = Controller()
681
- main.load(
682
- lambda x: x + 1,
683
- inputs=sketch_pad_trigger,
684
- outputs=sketch_pad_trigger,
685
- queue=False)
686
- sketch_pad.edit(
687
- draw,
688
- inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
689
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
690
- queue=False,
691
- )
692
- grounding_instruction.change(
693
- draw,
694
- inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
695
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
696
- queue=False,
697
- )
698
- clear_btn.click(
699
- clear,
700
- inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
701
- outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
702
- queue=False)
703
-
704
- sketch_pad_trigger.change(
705
- controller.init_white,
706
- inputs=[init_white_trigger],
707
- outputs=[sketch_pad, image_scale, init_white_trigger],
708
- queue=False)
709
- sketch_pad_resize_trigger.change(
710
- controller.resize_masked,
711
- inputs=[state],
712
- outputs=[sketch_pad, state],
713
- queue=False)
714
- # batch_size.change(
715
- # controller.change_n_samples,
716
- # inputs=[batch_size],
717
- # outputs=[out_gen_1, out_gen_2],
718
- # queue=False)
719
-
720
- # batch_size.change(
721
- # controller.change_n_samples,
722
- # inputs=[batch_size],
723
- # outputs=[out_gen_1, out_gen_2],
724
- # queue=False)
725
-
726
- gen_btn.click(
727
- generate,
728
- inputs=[
729
- language_instruction, grounding_instruction, sketch_pad,
730
- loss_threshold, guidance_scale, batch_size, rand_seed,
731
- max_step,
732
- Loss_scale, max_iter,
733
- state,
734
- ],
735
- outputs=[out_gen_1, state],
736
- queue=True
737
- )
738
- sketch_pad_resize_trigger.change(
739
- None,
740
- None,
741
- sketch_pad_resize_trigger,
742
- _js=rescale_js,
743
- queue=False)
744
- init_white_trigger.change(
745
- None,
746
- None,
747
- init_white_trigger,
748
- _js=rescale_js,
749
- queue=False)
750
-
751
- with gr.Column():
752
- gr.Examples(
753
- examples=[
754
- [
755
- # "images/input.png",
756
- "A hello kitty toy is playing with a purple ball.",
757
- "hello kitty;ball",
758
- "images/hello_kitty_results.png"
759
- ],
760
- ],
761
- inputs=[language_instruction, grounding_instruction, out_gen_1],
762
- outputs=None,
763
- fn=None,
764
- cache_examples=False,
765
- )
766
 
767
- main.queue(concurrency_count=1, api_open=False)
768
- main.launch(share=False, show_api=False, show_error=True)
 
1
  import gradio as gr
2
  import torch
3
  from omegaconf import OmegaConf
 
4
  from transformers import CLIPTextModel, CLIPTokenizer
5
  from diffusers import AutoencoderKL, LMSDiscreteScheduler
6
  from my_model import unet_2d_condition
 
8
  import numpy as np
9
  from PIL import Image, ImageDraw, ImageFont
10
  from functools import partial
 
11
  import math
 
12
  from utils import compute_ca_loss
13
  from gradio import processing_utils
14
  from typing import Optional
15
 
16
  import warnings
17
 
18
  import sys
19
 
20
  sys.tracebacklimit = 0
21
 
22
  class Blocks(gr.Blocks):
23
 
24
  def __init__(
 
71
  fill=(255, 255, 255))
72
  return img
73
 
74
+
75
  def inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg,attn_map, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
76
  uncond_input = tokenizer(
77
  [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
 
181
  return language_instruction
182
 
183
 
184
+ def generate(unet, vae, tokenizer, text_encoder, cfg, attn_map, language_instruction, grounding_texts, sketch_pad,
185
  loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
186
  state):
187
  if 'boxes' not in state:
 
259
 
260
 
261
  def draw(input, grounding_texts, new_image_trigger, state):
 
262
  if type(input) == dict:
263
  image = input['image']
264
  mask = input['mask']
265
  else:
266
  mask = input
267
  if mask.ndim == 3:
268
+ mask = 255 - mask[..., 0]
269
 
270
  image_scale = 1.0
271
 
272
  mask = binarize(mask)
273
 
274
  if type(mask) != np.ndarray:
 
307
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
308
  if len(grounding_texts) < len(state['boxes']):
309
  grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
 
310
  box_image = draw_box(state['boxes'], grounding_texts, image)
311
 
312
  return [box_image, new_image_trigger, image_scale, state]
313
 
314
 
 
316
  if task != 'Grounded Inpainting':
317
  sketch_pad_trigger = sketch_pad_trigger + 1
318
  blank_samples = batch_size % 2 if batch_size > 1 else 0
319
+ out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
320
+ # state = {}
321
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
322
+
323
+
324
+ def main():
325
 
326
+ css = """
327
+ #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
328
+ {
329
+ height: var(--height) !important;
330
+ max-height: var(--height) !important;
331
+ min-height: var(--height) !important;
332
+ }
333
+ #paper-info a {
334
+ color:#008AD7;
335
+ text-decoration: none;
336
+ }
337
+ #paper-info a:hover {
338
+ cursor: pointer;
339
+ text-decoration: none;
340
+ }
341
 
342
+ .tooltip {
343
+ color: #555;
344
+ position: relative;
345
+ display: inline-block;
346
+ cursor: pointer;
347
+ }
348
+
349
+ .tooltip .tooltiptext {
350
+ visibility: hidden;
351
+ width: 400px;
352
+ background-color: #555;
353
+ color: #fff;
354
+ text-align: center;
355
+ padding: 5px;
356
+ border-radius: 5px;
357
+ position: absolute;
358
+ z-index: 1; /* Set z-index to 1 */
359
+ left: 10px;
360
+ top: 100%;
361
+ opacity: 0;
362
+ transition: opacity 0.3s;
363
+ }
364
+
365
+ .tooltip:hover .tooltiptext {
366
+ visibility: visible;
367
+ opacity: 1;
368
+ z-index: 9999; /* Set a high z-index value when hovering */
369
+ }
370
+
371
+
372
+ """
373
+
374
+ rescale_js = """
375
+ function(x) {
376
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
377
+ let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
378
+ const image_width = root.querySelector('#img2img_image').clientWidth;
379
+ const target_height = parseInt(image_width * image_scale);
380
+ document.body.style.setProperty('--height', `${target_height}px`);
381
+ root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
382
+ root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
383
+ return x;
384
+ }
385
+ """
386
+ with open('./conf/unet/config.json') as f:
387
+ unet_config = json.load(f)
388
+
389
+ unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5',
390
+ subfolder="unet")
391
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
392
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
393
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
394
+ attn_map = None
395
+ cfg = OmegaConf.load('./conf/net_conf.yaml')
396
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
397
+ unet.to(device)
398
+ text_encoder.to(device)
399
+ vae.to(device)
400
+
401
+ with Blocks(
402
+ css=css,
403
+ analytics_enabled=False,
404
+ title="Layout-Guidance demo",
405
+ ) as demo:
406
+ description = """<p style="text-align: center; font-weight: bold;">
407
+ <span style="font-size: 28px">Layout Guidance</span>
408
+ <br>
409
+ <span style="font-size: 18px" id="paper-info">
410
+ [<a href=" " target="_blank">Project Page</a>]
411
+ [<a href=" " target="_blank">Paper</a>]
412
+ [<a href=" " target="_blank">GitHub</a>]
413
+ </span>
414
+ </p>
415
+ """
416
+ gr.HTML(description)
417
+ with gr.Column():
418
+ language_instruction = gr.Textbox(
419
+ label="Text Prompt",
420
+ )
421
+ grounding_instruction = gr.Textbox(
422
+ label="Grounding instruction (Separated by semicolon)",
423
+ )
424
+ sketch_pad_trigger = gr.Number(value=0, visible=False)
425
+ sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
426
+ init_white_trigger = gr.Number(value=0, visible=False)
427
+ image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
428
+ new_image_trigger = gr.Number(value=0, visible=False)
429
+
430
+
431
+
432
+ with gr.Row():
433
+ sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image", source='canvas', shape=(512, 512))
434
+
435
+ out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
436
+ out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
437
+
438
+ with gr.Row():
439
+ clear_btn = gr.Button(value='Clear')
440
+ gen_btn = gr.Button(value='Generate')
441
+
442
+ with gr.Accordion("Advanced Options", open=False):
443
+ with gr.Column():
444
+ description = """<div class="tooltip">Loss Scale Factor &#9432
445
+ <span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the better control we get while it sometimes losses fidelity. </span>
446
+ </div>
447
+ <div class="tooltip">Guidance Scale &#9432
448
+ <span class="tooltiptext">The scale factor of classifier-free guidance. </span>
449
+ </div>
450
+ <div class="tooltip" >Max Iteration per Step &#9432
451
+ <span class="tooltiptext">The max iterations of backward guidance in each diffusion inference process.</span>
452
+ </div>
453
+ <div class="tooltip" >Loss Threshold &#9432
454
+ <span class="tooltiptext">The threshold of loss. If the loss computed by cross-attention map is smaller then the threshold, the backward guidance is stopped. </span>
455
+ </div>
456
+ <div class="tooltip" >Max Step of Backward Guidance &#9432
457
+ <span class="tooltiptext">The max steps of backward guidance in diffusion inference process.</span>
458
+ </div>
459
+ """
460
+ gr.HTML(description)
461
+ Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,label="Loss Scale Factor")
462
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
463
+ batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
464
+ max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
465
+ loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
466
+ max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
467
+ # fix_seed = gr.Checkbox(value=True, label="Fixed seed")
468
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
469
+
470
+ state = gr.State({})
471
+
472
+
473
+ class Controller:
474
+ def __init__(self):
475
+ self.calls = 0
476
+ self.tracks = 0
477
+ self.resizes = 0
478
+ self.scales = 0
479
+
480
+ def init_white(self, init_white_trigger):
481
+ self.calls += 1
482
+ return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
483
+
484
+ def change_n_samples(self, n_samples):
485
+ blank_samples = n_samples % 2 if n_samples > 1 else 0
486
+ return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
487
+ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
488
+
489
+
490
+ controller = Controller()
491
+ demo.load(
492
+ lambda x: x + 1,
493
+ inputs=sketch_pad_trigger,
494
+ outputs=sketch_pad_trigger,
495
+ queue=False)
496
+ sketch_pad.edit(
497
+ draw,
498
+ inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
499
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
500
+ queue=False,
501
+ )
502
+ grounding_instruction.change(
503
+ draw,
504
+ inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
505
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
506
+ queue=False,
507
+ )
508
+ clear_btn.click(
509
+ clear,
510
+ inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
511
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
512
+ queue=False)
513
+
514
+ sketch_pad_trigger.change(
515
+ controller.init_white,
516
+ inputs=[init_white_trigger],
517
+ outputs=[sketch_pad, image_scale, init_white_trigger],
518
+ queue=False)
519
+
520
+ gen_btn.click(
521
+ fn=partial(generate, unet, vae, tokenizer, text_encoder, cfg, attn_map),
522
+ inputs=[
523
+ language_instruction, grounding_instruction, sketch_pad,
524
+ loss_threshold, guidance_scale, batch_size, rand_seed,
525
+ max_step,
526
+ Loss_scale, max_iter,
527
+ state,
528
+ ],
529
+ outputs=[out_gen_1, state],
530
+ queue=True
531
+ )
532
+ sketch_pad_resize_trigger.change(
533
+ None,
534
+ None,
535
+ sketch_pad_resize_trigger,
536
+ _js=rescale_js,
537
+ queue=False)
538
+ init_white_trigger.change(
539
+ None,
540
+ None,
541
+ init_white_trigger,
542
+ _js=rescale_js,
543
+ queue=False)
544
+
545
+ with gr.Column():
546
+ gr.Examples(
547
+ examples=[
548
+ [
549
+ # "images/input.png",
550
+ "A hello kitty toy is playing with a purple ball.",
551
+ "hello kitty;ball",
552
+ "images/hello_kitty_results.png"
553
+ ],
554
+ ],
555
+ inputs=[language_instruction, grounding_instruction, out_gen_1],
556
+ outputs=None,
557
+ fn=None,
558
+ cache_examples=False,
559
+ )
560
+ description = """<p> The source codes of the demo are modified based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GlIGen</a>. Thanks! </p>"""
561
+ gr.HTML(description)
562
+
563
+ demo.queue(concurrency_count=1, api_open=False)
564
+ demo.launch(share=False, show_api=False, show_error=True)
565
+
566
+ if __name__ == '__main__':
567
+ main()
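
Note on the Advanced Options sliders wired up above: `inference` performs backward guidance by stepping the noisy latents along the gradient of a cross-attention layout loss (`compute_ca_loss`). The loop below is a minimal, self-contained sketch of that update, with a toy `box_loss` standing in for the real cross-attention loss and an illustrative function name that does not come from app.py:

import torch

def backward_guidance_step(latents, loss_fn, loss_scale=30.0, loss_threshold=0.2, max_iter=5):
    # Mirrors the Loss Scale Factor / Loss Threshold / Max Iteration per Step knobs:
    # repeatedly push the latents down the gradient of a layout loss.
    for _ in range(max_iter):
        latents = latents.detach().requires_grad_(True)
        loss = loss_fn(latents)                    # stand-in for compute_ca_loss(...)
        if loss.item() < loss_threshold:
            break                                  # guidance considered converged for this step
        grad = torch.autograd.grad(loss * loss_scale, latents)[0]
        latents = latents - grad                   # gradient step on the noisy latents
    return latents.detach()

if __name__ == "__main__":
    # Toy usage: drain latent energy inside a 32x32 "box" region.
    lat = torch.randn(1, 4, 64, 64)
    box_loss = lambda z: z[:, :, 16:48, 16:48].pow(2).mean()
    print(backward_guidance_step(lat, box_loss).shape)

Max Step of Backward Guidance (passed into `inference` as `max_index_step`) would then cap how many diffusion timesteps apply this extra update at all.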