Ubuntu committed on
Commit
e5efca7
•
1 Parent(s): c071a86

Update Inpainting Demo

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. .log/log.txt +6 -0
  3. SegFormer +1 -0
  4. output.png +0 -0
  5. requirements.txt +2 -2
  6. test.png +0 -0
  7. test.py +168 -76
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  __pycache__
2
  *.pyc
3
  checkpoints/
 
4
  *.pth
 
1
  __pycache__
2
  *.pyc
3
  checkpoints/
4
+ I2SB/
5
  *.pth
.log/log.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [19:02:29] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
2
+ INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
3
+ [19:02:33] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
4
+ [19:02:44] INFO (0:00:14) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
5
+ [19:02:49] INFO (0:00:19) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
6
+ [19:02:50] INFO (0:00:20) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
SegFormer ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 64ab11278eb30b8e2d8ea1d10a777fc5b1563948
output.png ADDED
requirements.txt CHANGED
@@ -18,8 +18,8 @@ timm
18
  # torch==2.0.0
19
  # torchvision==0.15.1
20
 
21
- torch==2.2.1
22
- torchvision==0.17.1
23
 
24
  gevent
25
  yapf
 
18
  # torch==2.0.0
19
  # torchvision==0.15.1
20
 
21
+ # torch==2.2.1
22
+ # torchvision==0.17.1
23
 
24
  gevent
25
  yapf
test.png ADDED
test.py CHANGED
@@ -36,6 +36,34 @@ from GroundingDINO.groundingdino.util import box_ops
36
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
37
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  import cv2
40
  import numpy as np
41
  import matplotlib
@@ -126,6 +154,30 @@ kosmos_processor = None
126
  colors = [(255, 0, 0), (0, 255, 0)]
127
  markers = [1, 5]
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def get_point(img, sel_pix, evt: gr.SelectData):
130
  img = np.array(img, dtype=np.uint8)
131
  sel_pix.append(evt.index)
@@ -146,6 +198,10 @@ def undo_button(orig_img, sel_pix):
146
  for point in sel_pix:
147
  cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
148
  return Image.fromarray(temp).convert("RGB")
 
 
 
 
149
 
150
  def toggle_button(orig_img, task_type):
151
  print(task_type)
@@ -173,6 +229,37 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
173
  _ = model.eval()
174
  return model
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def plot_boxes_to_image(image_pil, tgt):
177
  H, W = tgt["size"]
178
  boxes = tgt["boxes"]
@@ -238,6 +325,8 @@ def load_image(image_path):
238
  image, _ = transform(image_pil, None) # 3, h, w
239
  return image_pil, image
240
 
 
 
241
  def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
242
  caption = caption.lower()
243
  caption = caption.strip()
@@ -357,6 +446,24 @@ def load_sd_model(device):
357
  torch_dtype=torch.float16,
358
  )
359
  sd_model = sd_model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
362
  try:
@@ -511,7 +618,7 @@ def concatenate_images_vertical(image1, image2):
511
  return new_image
512
 
513
  mask_source_draw = "draw a mask on input image"
514
- mask_source_segment = "type what to detect below"
515
 
516
  def get_time_cost(run_task_time, time_cost_str):
517
  now_time = int(time.time()*1000)
@@ -524,11 +631,8 @@ def get_time_cost(run_task_time, time_cost_str):
524
  run_task_time = now_time
525
  return run_task_time, time_cost_str
526
 
527
- def run_anything_task(input_image, input_points, origin_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
528
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
529
-
530
- text_prompt = getTextTrans(text_prompt, source='zh', target='en')
531
- inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
532
 
533
  run_task_time = 0
534
  time_cost_str = ''
@@ -543,27 +647,19 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
543
  image_pil, image = load_image(input_image.convert("RGB"))
544
  input_img = input_image
545
 
546
- kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
547
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
548
  return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
549
 
550
- text_prompt = text_prompt.strip()
551
- # if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
552
- # if text_prompt == '':
553
- # return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
554
-
555
  if input_image is None:
556
  return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
557
 
558
  file_temp = int(time.time())
559
- logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
560
 
561
  output_images = []
562
 
563
  # load image
564
- if mask_source_radio == mask_source_draw:
565
- input_mask_pil = input_image['mask']
566
- input_mask = np.array(input_mask_pil.convert("L"))
567
 
568
  if isinstance(input_image, dict):
569
  image_pil, image = load_image(input_image['image'].convert("RGB"))
@@ -626,17 +722,17 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
626
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
627
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
628
  elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
629
- if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
630
  task_type = 'remove'
631
 
632
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
633
  if mask_source_radio == mask_source_draw:
 
 
634
  mask_pil = input_mask_pil
635
  mask = input_mask
636
  else:
637
  masks_ori = copy.deepcopy(masks)
638
- if inpaint_mode == 'merge':
639
- masks = torch.sum(masks, dim=0).unsqueeze(0)
640
  masks = torch.where(masks > 0, True, False)
641
  mask = masks[0][0].cpu().numpy()
642
  mask_pil = Image.fromarray(mask)
@@ -644,18 +740,11 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
644
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
645
 
646
  if task_type in ['inpainting', 'outpainting']:
647
- # inpainting pipeline
648
- image_source_for_inpaint = image_pil.resize((512, 512))
649
- image_mask_for_inpaint = mask_pil.resize((512, 512))
650
- if task_type in ['outpainting']:
651
- # reverse mask
652
- img_arr = np.array(image_mask_for_inpaint)
653
- img_arr = np.where(img_arr > 0, 1, img_arr)
654
- img_arr = 1 - img_arr
655
- image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
656
- output_images.append(image_mask_for_inpaint.convert("RGB"))
657
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
658
- image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
659
  else:
660
  # remove from mask
661
  aasds = 1
@@ -681,8 +770,6 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
681
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
682
 
683
  def change_radio_display(task_type, mask_source_radio, orig_img):
684
- text_prompt_visible = True
685
- inpaint_prompt_visible = False
686
  mask_source_radio_visible = False
687
  num_relation_visible = False
688
 
@@ -693,35 +780,29 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
693
  print(task_type)
694
  if task_type == "Kosmos-2":
695
  if kosmos_enable:
696
- text_prompt_visible = False
697
  image_gallery_visible = False
698
  kosmos_input_visible = True
699
  kosmos_output_visible = True
700
  kosmos_text_output_visible = True
701
 
702
- if task_type in ['inpainting', 'outpainting']:
703
- inpaint_prompt_visible = False
704
  if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
705
  mask_source_radio_visible = True
706
- if mask_source_radio == mask_source_draw:
707
- text_prompt_visible = False
708
  if task_type == "relate anything":
709
- text_prompt_visible = False
710
  num_relation_visible = True
711
  if task_type == "segment":
712
  ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
713
  elif task_type == "inpainting":
714
  ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
715
 
716
- return (gr.Textbox.update(visible=text_prompt_visible),
717
- gr.Textbox.update(visible=inpaint_prompt_visible),
718
- gr.Radio.update(visible=mask_source_radio_visible),
719
  gr.Slider.update(visible=num_relation_visible),
720
  gr.Gallery.update(visible=image_gallery_visible),
721
  gr.Radio.update(visible=kosmos_input_visible),
722
  gr.Image.update(visible=kosmos_output_visible),
723
  gr.HighlightedText.update(visible=kosmos_text_output_visible),
724
- ret, [], gr.Button("Undo point", visible = task_type == "segment"))
 
 
725
 
726
  def get_model_device(module):
727
  try:
@@ -770,42 +851,52 @@ def main_gradio(args):
770
  [input_image, selected_points],
771
  [input_image]
772
  )
773
-
774
- undo_point_button = gr.Button("Undo point")
775
- undo_point_button.click(
776
- fn= undo_button,
777
- inputs=[original_image, selected_points],
778
- outputs=[input_image]
779
- )
 
 
 
 
 
 
 
 
 
 
 
 
780
  print(dir(input_image))
781
  task_type = gr.Radio(task_types, value="segment",
782
  label='Task type', visible=True)
783
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
784
- value=mask_source_segment, label="Mask from",
785
  visible=False)
786
- text_prompt = gr.Textbox(label="Detection", placeholder="Cannot be empty")
787
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
788
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
789
 
790
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
791
 
792
  run_button = gr.Button(label="Run", visible=True)
793
- with gr.Accordion("Advanced options", open=False) as advanced_options:
794
- box_threshold = gr.Slider(
795
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
796
- )
797
- text_threshold = gr.Slider(
798
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
799
- )
800
- iou_threshold = gr.Slider(
801
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
802
- )
803
- inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
804
- with gr.Row():
805
- with gr.Column(scale=1):
806
- remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
807
- with gr.Column(scale=1):
808
- remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
809
 
810
  with gr.Column():
811
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
@@ -841,15 +932,15 @@ def main_gradio(args):
841
  selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
842
 
843
  run_button.click(fn=run_anything_task, inputs=[
844
- input_image, selected_points, original_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
845
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
846
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
847
 
848
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
849
- outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
850
  task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
851
- outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
852
- image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button
853
  ])
854
 
855
  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
@@ -895,8 +986,9 @@ if __name__ == "__main__":
895
  if sam_enable:
896
  load_sam_model(device)
897
 
898
- # if inpainting_enable:
899
- # load_sd_model(device)
 
900
 
901
  # if lama_cleaner_enable:
902
  # load_lama_cleaner_model(device)
 
36
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
37
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
38
 
39
+ # I2SB
40
+ import sys
41
+
42
+ sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
43
+
44
+ import numpy as np
45
+ import torch
46
+ import torch.distributed as dist
47
+ import torchvision.transforms as transforms
48
+ import torchvision.utils as tu
49
+ from easydict import EasyDict as edict
50
+ from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException, Query,
51
+ UploadFile)
52
+ from ipdb import set_trace as debug
53
+ from PIL import Image
54
+ from torch.multiprocessing import Process
55
+ from torch.utils.data import DataLoader, Subset
56
+ from torch_ema import ExponentialMovingAverage
57
+
58
+ import I2SB.distributed_util as dist_util
59
+ from I2SB.corruption import build_corruption
60
+ from I2SB.dataset import air_liquide
61
+ from I2SB.i2sb import Runner, ckpt_util, download_ckpt
62
+ from I2SB.logger import Logger
63
+ from I2SB.sample import *
64
+
65
+
66
+
67
  import cv2
68
  import numpy as np
69
  import matplotlib
 
154
  colors = [(255, 0, 0), (0, 255, 0)]
155
  markers = [1, 5]
156
 
157
# Fixed option set for the I2SB inpainting model. It is read by
# load_i2sb_model() (ckpt, nfe, use_fp16) and forward_i2sb() (device,
# clip_denoise, n_gpu_per_node); the remaining keys are presumably consumed
# inside the I2SB package itself -- confirm before removing any of them.
i2sb_opt = edict({
    "distributed": False,
    "device": "cuda",
    "batch_size": 1,
    "nfe": 10,
    "dataset": "sample",
    "dataset_dir": Path("dataset/sample"),
    "n_gpu_per_node": 1,
    "use_fp16": False,
    "ckpt": "inpaint-freeform2030",
    "image_size": 256,
    "partition": None,
    "global_size": 1,
    "global_rank": 0,
    "clip_denoise": True,
})

# Image preprocessing for I2SB: resize, centre-crop to a square of
# ``image_size`` pixels, convert to a tensor, then rescale [0, 1] -> [-1, 1].
i2sb_transforms = transforms.Compose(
    [
        transforms.Resize(i2sb_opt.image_size),
        transforms.CenterCrop(i2sb_opt.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t * 2 - 1),
    ]
)
180
+
181
  def get_point(img, sel_pix, evt: gr.SelectData):
182
  img = np.array(img, dtype=np.uint8)
183
  sel_pix.append(evt.index)
 
198
  for point in sel_pix:
199
  cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
200
  return Image.fromarray(temp).convert("RGB")
201
+
202
def clear_button(orig_img):
    """Drop every selected point.

    Returns the untouched original image together with a brand-new empty
    list, which the UI wiring assigns to the image widget and the
    selected-points state respectively.
    """
    no_points = []
    return orig_img, no_points
205
 
206
  def toggle_button(orig_img, task_type):
207
  print(task_type)
 
229
  _ = model.eval()
230
  return model
231
 
232
def load_i2sb_model():
    """Build the I2SB runner and publish it through module globals.

    Sets the globals ``ckpt_opt``, ``corrupt_type``, ``nfe`` and
    ``i2sb_model`` (all read later by the sampling code), prints how long
    loading took, and returns the runner.
    """
    global i2sb_model, ckpt_opt, corrupt_type, nfe

    results_root = Path("I2SB/results")
    started_at = time.time()

    # The I2SB logger is pointed at the ".log" directory.
    log = Logger(0, ".log")

    # Checkpoint options come from the run directory named by i2sb_opt.ckpt.
    ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, results_root / i2sb_opt.ckpt)
    corrupt_type = ckpt_opt.corrupt
    # An explicit (truthy) nfe wins; otherwise fall back to interval - 1.
    nfe = i2sb_opt.nfe if i2sb_opt.nfe else ckpt_opt.interval - 1

    runner = Runner(ckpt_opt, log, save_opt=False)
    if i2sb_opt.use_fp16:
        # Copy the EMA weights into the net, convert the net to fp16, then
        # rebuild the EMA so it tracks the converted parameters.
        runner.ema.copy_to()
        runner.net.diffusion_model.convert_to_fp16()
        fp16_params = runner.net.parameters()
        runner.ema = ExponentialMovingAverage(fp16_params, decay=0.99)

    elapsed_ms = (time.time() - started_at) * 1e3
    print("Loading time:", elapsed_ms, "ms.")
    i2sb_model = runner
    return i2sb_model
262
+
263
  def plot_boxes_to_image(image_pil, tgt):
264
  H, W = tgt["size"]
265
  boxes = tgt["boxes"]
 
325
  image, _ = transform(image_pil, None) # 3, h, w
326
  return image_pil, image
327
 
328
+
329
+
330
  def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
331
  caption = caption.lower()
332
  caption = caption.strip()
 
446
  torch_dtype=torch.float16,
447
  )
448
  sd_model = sd_model.to(device)
449
+
450
def forward_i2sb(img, mask):
    """Inpaint the masked region of ``img`` with the loaded I2SB model.

    Args:
        img: PIL image to inpaint; it is preprocessed with ``i2sb_transforms``
            (resize + centre crop to ``i2sb_opt.image_size``, scaled to [-1, 1]).
        mask: 2-D array whose non-zero entries mark the pixels to regenerate.
            Presumably it has the same height/width as ``img`` (confirm
            against the caller); if not, it is first resized to ``img``'s size.

    Returns:
        The reconstructed image as a PIL image with values mapped back to
        [0, 1]. The same image is also written to ``output.png``.
    """
    print(np.unique(img), mask.shape)
    size = i2sb_opt.image_size

    img_tensor = i2sb_transforms(img).to(i2sb_opt.device).unsqueeze(0)

    # The mask must go through the same geometry as the image. np.resize()
    # only truncates/repeats the flattened buffer and does NOT rescale, so it
    # produced a mask that no longer matched the picture. Nearest-neighbour
    # interpolation keeps the mask strictly binary.
    mask_img = Image.fromarray((np.asarray(mask) > 0).astype(np.uint8) * 255)
    if mask_img.size != img.size:
        mask_img = mask_img.resize(img.size, Image.NEAREST)
    mask_img = transforms.Resize(
        size, interpolation=transforms.InterpolationMode.NEAREST)(mask_img)
    mask_img = transforms.CenterCrop(size)(mask_img)

    # Shape (1, 1, size, size), same dtype/device as the image tensor.
    mask_tensor = torch.from_numpy(np.array(mask_img) > 0).to(
        device=i2sb_opt.device, dtype=img_tensor.dtype).unsqueeze(0).unsqueeze(0)
    print("POST PROCESSING\t", torch.unique(img_tensor))

    xs, _ = i2sb_model.ddpm_sampling(
        ckpt_opt, img_tensor, mask=mask_tensor, cond=None,
        clip_denoise=i2sb_opt.clip_denoise, nfe=nfe,
        verbose=i2sb_opt.n_gpu_per_node == 1)
    # Index 0 along dim 1 is presumably the final sample -- confirm in I2SB.
    recon_img = xs[:, 0, ...].to(i2sb_opt.device)

    # Map [-1, 1] -> [0, 1] and clamp: ToPILImage multiplies by 255 and casts
    # to uint8 without clamping, so out-of-range values would wrap around.
    recon_unit = ((recon_img + 1) / 2).clamp(0, 1)
    tu.save_image(recon_unit, "output.png")
    print(recon_img.shape)
    return transforms.ToPILImage()(recon_unit[0])
467
 
468
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
469
  try:
 
618
  return new_image
619
 
620
  mask_source_draw = "draw a mask on input image"
621
+ mask_source_segment = "upload a mask"
622
 
623
  def get_time_cost(run_task_time, time_cost_str):
624
  now_time = int(time.time()*1000)
 
631
  run_task_time = now_time
632
  return run_task_time, time_cost_str
633
 
634
+ def run_anything_task(input_image, input_points, origin_image, task_type,
635
+ mask_source_radio, cleaner_size_limit=1080):
 
 
 
636
 
637
  run_task_time = 0
638
  time_cost_str = ''
 
647
  image_pil, image = load_image(input_image.convert("RGB"))
648
  input_img = input_image
649
 
650
+ kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
651
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
652
  return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
653
 
 
 
 
 
 
654
  if input_image is None:
655
  return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
656
 
657
  file_temp = int(time.time())
658
+ logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
659
 
660
  output_images = []
661
 
662
  # load image
 
 
 
663
 
664
  if isinstance(input_image, dict):
665
  image_pil, image = load_image(input_image['image'].convert("RGB"))
 
722
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
723
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
724
  elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
725
+ if mask_source_radio == mask_source_segment:
726
  task_type = 'remove'
727
 
728
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
729
  if mask_source_radio == mask_source_draw:
730
+ input_mask_pil = input_image['mask']
731
+ input_mask = np.array(input_mask_pil.convert("L"))
732
  mask_pil = input_mask_pil
733
  mask = input_mask
734
  else:
735
  masks_ori = copy.deepcopy(masks)
 
 
736
  masks = torch.where(masks > 0, True, False)
737
  mask = masks[0][0].cpu().numpy()
738
  mask_pil = Image.fromarray(mask)
 
740
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
741
 
742
  if task_type in ['inpainting', 'outpainting']:
743
+ # image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
744
+ input_img.save("test.png")
745
+ image_inpainting = forward_i2sb(input_img, mask)
746
+
747
+ print("RESULT\t", np.array(image_inpainting))
 
 
 
 
 
 
 
748
  else:
749
  # remove from mask
750
  aasds = 1
 
770
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
771
 
772
  def change_radio_display(task_type, mask_source_radio, orig_img):
 
 
773
  mask_source_radio_visible = False
774
  num_relation_visible = False
775
 
 
780
  print(task_type)
781
  if task_type == "Kosmos-2":
782
  if kosmos_enable:
 
783
  image_gallery_visible = False
784
  kosmos_input_visible = True
785
  kosmos_output_visible = True
786
  kosmos_text_output_visible = True
787
 
 
 
788
  if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
789
  mask_source_radio_visible = True
 
 
790
  if task_type == "relate anything":
 
791
  num_relation_visible = True
792
  if task_type == "segment":
793
  ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
794
  elif task_type == "inpainting":
795
  ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
796
 
797
+ return (gr.Radio.update(visible=mask_source_radio_visible),
 
 
798
  gr.Slider.update(visible=num_relation_visible),
799
  gr.Gallery.update(visible=image_gallery_visible),
800
  gr.Radio.update(visible=kosmos_input_visible),
801
  gr.Image.update(visible=kosmos_output_visible),
802
  gr.HighlightedText.update(visible=kosmos_text_output_visible),
803
+ ret, [],
804
+ gr.Button("Undo point", visible = task_type == "segment"),
805
+ gr.Button("Clear point", visible = task_type == "segment"),)
806
 
807
  def get_model_device(module):
808
  try:
 
851
  [input_image, selected_points],
852
  [input_image]
853
  )
854
+ with gr.Row():
855
+ with gr.Column():
856
+
857
+ undo_point_button = gr.Button("Undo point")
858
+ undo_point_button.click(
859
+ fn= undo_button,
860
+ inputs=[original_image, selected_points],
861
+ outputs=[input_image]
862
+ )
863
+
864
+ with gr.Column():
865
+
866
+ clear_point_button = gr.Button("Clear point")
867
+ clear_point_button.click(
868
+ fn= clear_button,
869
+ inputs=[original_image],
870
+ outputs=[input_image, selected_points]
871
+ )
872
+
873
  print(dir(input_image))
874
  task_type = gr.Radio(task_types, value="segment",
875
  label='Task type', visible=True)
876
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
877
+ value=mask_source_draw, label="Mask from",
878
  visible=False)
 
 
879
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
880
 
881
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
882
 
883
  run_button = gr.Button(label="Run", visible=True)
884
+ # with gr.Accordion("Advanced options", open=False) as advanced_options:
885
+ # box_threshold = gr.Slider(
886
+ # label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
887
+ # )
888
+ # text_threshold = gr.Slider(
889
+ # label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
890
+ # )
891
+ # iou_threshold = gr.Slider(
892
+ # label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
893
+ # )
894
+ # inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
895
+ # with gr.Row():
896
+ # with gr.Column(scale=1):
897
+ # remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
898
+ # with gr.Column(scale=1):
899
+ # remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
900
 
901
  with gr.Column():
902
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
 
932
  selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
933
 
934
  run_button.click(fn=run_anything_task, inputs=[
935
+ input_image, selected_points, original_image, task_type,
936
+ mask_source_radio],
937
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
938
 
939
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
940
+ outputs=[mask_source_radio, num_relation])
941
  task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
942
+ outputs=[mask_source_radio, num_relation,
943
+ image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button, clear_point_button
944
  ])
945
 
946
  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
 
986
  if sam_enable:
987
  load_sam_model(device)
988
 
989
+ if inpainting_enable:
990
+ load_sd_model(device)
991
+ load_i2sb_model()
992
 
993
  # if lama_cleaner_enable:
994
  # load_lama_cleaner_model(device)